[TVMWrapper] Add stream argument in BeginForward #164

Merged
6 changes: 3 additions & 3 deletions include/flashinfer/attention/decode.cuh
@@ -1111,9 +1111,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
   // compute max_num_pages_per_batch and new_batch_size
   std::vector<IdType> page_indptr_h(batch_size + 1), num_pages(batch_size);
   if (is_device_ptr(kv_indptr)) {
-    FLASHINFER_CUDA_CALL(cudaMemcpy(page_indptr_h.data(), kv_indptr,
-                                    sizeof(IdType) * (batch_size + 1),
-                                    cudaMemcpyDeviceToHost));
+    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(page_indptr_h.data(), kv_indptr,
+                                         sizeof(IdType) * (batch_size + 1),
+                                         cudaMemcpyDeviceToHost, stream));
   } else {
     page_indptr_h.assign(kv_indptr, kv_indptr + batch_size + 1);
   }
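The hunk above replaces a blocking cudaMemcpy with cudaMemcpyAsync issued on the stream passed into the work-estimation routine, so the device-to-host copy of the page indptr is ordered with the other work on that stream rather than on the default stream. As a standalone illustration of the same pattern (the function and variable names below are invented for the sketch and are not FlashInfer APIs), the host still has to synchronize the stream before it reads the copied data:

```cpp
// Minimal sketch (not FlashInfer code): copy a device-side indptr array to the
// host on a caller-provided stream, then synchronize before the host reads it.
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

cudaError_t CopyIndptrToHost(const int32_t* indptr_d, size_t batch_size,
                             std::vector<int32_t>& indptr_h, cudaStream_t stream) {
  indptr_h.resize(batch_size + 1);
  // Enqueue the D2H copy on `stream` so it is ordered with the rest of the
  // work on that stream instead of the default stream.
  cudaError_t status = cudaMemcpyAsync(indptr_h.data(), indptr_d,
                                       sizeof(int32_t) * (batch_size + 1),
                                       cudaMemcpyDeviceToHost, stream);
  if (status != cudaSuccess) return status;
  // The host is about to read indptr_h, so wait for the copy to finish.
  return cudaStreamSynchronize(stream);
}
```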
16 changes: 13 additions & 3 deletions src/tvm_wrapper.cu
@@ -328,10 +328,13 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q

 void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
     int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size,
-    int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim) {
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
   CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
+  cudaStream_t original_stream = batch_prefill_paged_kv_handlers[handler_idx].GetCUDAStream();
+  batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream(
+      static_cast<cudaStream_t>(copy_stream));
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].BeginForward(
         static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
@@ -340,6 +343,7 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
       LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status);
     }
   });
+  batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream(original_stream);
 }

 void _FlashInferAttentionPrefillWithPagedKVCacheEndForward(int64_t handler_idx) {
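This wrapper and the decode/ragged variants below all follow the same save/redirect/restore sequence around BeginForward: remember the handler's current stream, point it at the caller-supplied copy_stream, run BeginForward, then put the original stream back. A hypothetical RAII guard (not part of this PR or of FlashInfer) expresses the same sequence, assuming only the GetCUDAStream()/SetCUDAStream() pair visible in the diff:

```cpp
// Hypothetical sketch: scope guard mirroring the manual save/restore above.
// HandlerT stands in for the handler type; only GetCUDAStream() and
// SetCUDAStream(cudaStream_t) are assumed, as used in the diff.
#include <cuda_runtime.h>

template <typename HandlerT>
class ScopedCUDAStream {
 public:
  ScopedCUDAStream(HandlerT& handler, cudaStream_t tmp_stream)
      : handler_(handler), original_stream_(handler.GetCUDAStream()) {
    handler_.SetCUDAStream(tmp_stream);  // redirect for the duration of the scope
  }
  ~ScopedCUDAStream() { handler_.SetCUDAStream(original_stream_); }  // restore on scope exit
  ScopedCUDAStream(const ScopedCUDAStream&) = delete;
  ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;

 private:
  HandlerT& handler_;
  cudaStream_t original_stream_;
};

// Illustrative usage:
// {
//   ScopedCUDAStream<decltype(handler)> guard(handler, static_cast<cudaStream_t>(copy_stream));
//   handler.BeginForward(/* ... */);
// }  // original stream restored here
```

Since the error path in the wrapper goes through LOG(FATAL), which aborts the process, the manual restore in the PR behaves the same as a guard in practice; the sketch is only meant to name the pattern.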
@@ -456,7 +460,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
 void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
     int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* page_table_indptr,
     DLTensor* last_page_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-    int64_t page_size, int64_t pos_encoding_mode) {
+    int64_t page_size, int64_t pos_encoding_mode, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
   CHECK_LT(handler_idx, max_num_handlers)
@@ -467,6 +471,8 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
   // leave a parameter for the input data type.
   using dtype_in = half;
   const uint32_t batch_size = page_table_indptr->shape[0] - 1;
+  cudaStream_t original_stream = batch_decode_handlers[handler_idx].GetCUDAStream();
+  batch_decode_handlers[handler_idx].SetCUDAStream(static_cast<cudaStream_t>(copy_stream));
   DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, {
     cudaError_t status =
         batch_decode_handlers[handler_idx]
@@ -479,6 +485,7 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
       LOG(FATAL) << "FlashInfer decode BeginForward error " << cudaGetErrorString(status);
     }
   });
+  batch_decode_handlers[handler_idx].SetCUDAStream(original_stream);
 }

 void _FlashInferAttentionDecodeWithPagedKVCacheEndForward(int64_t handler_id) {
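Each BeginForward wrapper derives the workspace size from the 1-D workspace tensor the same way: element count times element width in bits, divided by eight. A minimal, self-contained illustration of that arithmetic with made-up values (a 128 Mi-element uint8 buffer gives 128 MiB); the helper name is hypothetical:

```cpp
// Illustrative only: reproduce the size computation the wrappers apply to the
// 1-D workspace DLTensor.
#include <dlpack/dlpack.h>
#include <cstddef>
#include <cstdint>

size_t WorkspaceSizeInBytes(const DLTensor* workspace_buffer) {
  // The wrappers check ndim == 1 before doing this arithmetic.
  return static_cast<size_t>(workspace_buffer->shape[0]) * workspace_buffer->dtype.bits / 8;
}

int main() {
  int64_t shape[1] = {128 * 1024 * 1024};  // 128 Mi elements
  DLTensor ws{};
  ws.ndim = 1;
  ws.shape = shape;
  ws.dtype.code = kDLUInt;  // uint8: 8 bits per element, 1 lane
  ws.dtype.bits = 8;
  ws.dtype.lanes = 1;
  // 128 Mi * 8 bits / 8 = 134217728 bytes = 128 MiB
  return WorkspaceSizeInBytes(&ws) == (128u << 20) ? 0 : 1;
}
```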
@@ -606,9 +613,11 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(

 void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
     DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size, int64_t num_qo_heads,
-    int64_t num_kv_heads, int64_t head_dim) {
+    int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+  cudaStream_t original_stream = batch_prefill_ragged_kv_handler.GetCUDAStream();
+  batch_prefill_ragged_kv_handler.SetCUDAStream(static_cast<cudaStream_t>(copy_stream));

   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward(
@@ -619,6 +628,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
                  << cudaGetErrorString(status);
     }
   });
+  batch_prefill_ragged_kv_handler.SetCUDAStream(original_stream);
 }

 void _FlashInferAttentionPrefillWithRaggedKVCacheEndForward() {
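From the caller's perspective the new copy_stream parameter is a TVMStreamHandle, i.e. an opaque void*, so an ordinary CUDA stream can be passed through a cast. A hedged caller-side sketch, assuming the caller creates and owns a dedicated non-blocking stream for these BeginForward copies; the TVM-side plumbing that actually forwards the handle into these wrappers is outside this diff:

```cpp
// Hypothetical caller-side sketch: create a dedicated non-blocking stream and
// hand it to a wrapper that now takes a TVMStreamHandle. The wrapper call is
// abbreviated; only the stream handling is the point here.
#include <cuda_runtime.h>
#include <tvm/runtime/c_runtime_api.h>  // TVMStreamHandle (an opaque void*)

int main() {
  cudaStream_t copy_stream = nullptr;
  // Non-blocking: work on this stream does not implicitly synchronize with the
  // legacy default stream.
  if (cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking) != cudaSuccess) return 1;

  TVMStreamHandle handle = static_cast<TVMStreamHandle>(copy_stream);
  // _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(handler_idx, workspace_buffer,
  //                                                         qo_indptr, batch_size, num_qo_heads,
  //                                                         num_kv_heads, head_dim, handle);
  (void)handle;

  cudaStreamDestroy(copy_stream);
  return 0;
}
```

Note that cudaStreamNonBlocking only removes the implicit synchronization with the legacy default stream; ordering against any other streams remains the caller's responsibility.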