Skip to content

Commit

Permalink
upd
Browse files (browse the repository at this point in the history)
  • Loading branch information
yzh119 committed Dec 14, 2024
1 parent 0d7743b commit 8021dbc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
6 changes: 4 additions & 2 deletions csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T
at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
unsigned int layout, int64_t cuda_stream);

-void block_sparse_indices_to_vector_sparse_offsets(at::Tensor input, at::Tensor input_indptr,
-at::Tensor output, at::Tensor output_indptr,
+void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
+at::Tensor block_sparse_indptr,
+at::Tensor vector_sparse_offsets,
+at::Tensor vector_sparse_indptr,
at::Tensor kv_len_arr, unsigned int stride_block,
unsigned int stride_n, unsigned int batch_size,
unsigned int block_size, int64_t cuda_stream);
Expand Down
6 changes: 4 additions & 2 deletions csrc/flashinfer_page_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T
at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
unsigned int layout, int64_t cuda_stream);

-void block_sparse_indices_to_vector_sparse_offsets(at::Tensor input, at::Tensor input_indptr,
-at::Tensor output, at::Tensor output_indptr,
+void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
+at::Tensor block_sparse_indptr,
+at::Tensor vector_sparse_offsets,
+at::Tensor vector_sparse_indptr,
at::Tensor kv_len_arr, unsigned int stride_block,
unsigned int stride_n, unsigned int batch_size,
unsigned int block_size, int64_t cuda_stream);
Expand Down
16 changes: 9 additions & 7 deletions csrc/page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,24 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T

void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
at::Tensor block_sparse_indptr,
-at::Tensor vector_sparse_indices,
+at::Tensor vector_sparse_offsets,
at::Tensor vector_sparse_indptr,
at::Tensor kv_len_arr, unsigned int stride_block,
unsigned int stride_n, unsigned int batch_size,
unsigned int block_size, int64_t cuda_stream) {
-CHECK_INPUT(input);
-CHECK_INPUT(input_indptr);
-CHECK_INPUT(output);
-CHECK_INPUT(output_indptr);
+CHECK_INPUT(block_sparse_indices);
+CHECK_INPUT(block_sparse_indptr);
+CHECK_INPUT(vector_sparse_offsets);
+CHECK_INPUT(vector_sparse_indptr);
CHECK_INPUT(kv_len_arr);

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

cudaError_t status = BlockSparseIndicesToVectorSparseOffset(
-static_cast<int32_t*>(input.data_ptr()), static_cast<int32_t*>(input_indptr.data_ptr()),
-static_cast<int32_t*>(output.data_ptr()), static_cast<int32_t*>(output_indptr.data_ptr()),
+static_cast<int32_t*>(block_sparse_indices.data_ptr()),
+static_cast<int32_t*>(block_sparse_indptr.data_ptr()),
+static_cast<int32_t*>(vector_sparse_offsets.data_ptr()),
+static_cast<int32_t*>(vector_sparse_indptr.data_ptr()),
static_cast<int32_t*>(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size,
stream);

Expand Down

0 comments on commit 8021dbc

Please sign in to comment.