Use Torch's current stream for ops #111

Merged: 1 commit, Feb 8, 2024
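The change: every binding previously launched FlashInfer kernels on CUDA's legacy default stream (passing /*stream=*/nullptr, or omitting a trailing stream argument entirely); after this PR each wrapper first fetches the stream PyTorch is currently recording work on via c10::cuda::getCurrentCUDAStream() and threads it through to the kernel launch. A minimal sketch of the pattern, where LaunchMyKernel is a hypothetical stand-in for the FlashInfer kernel entry points:

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

void my_op(torch::Tensor x) {
  // The stream PyTorch is currently recording work on for this device;
  // respects `with torch.cuda.stream(s):` contexts on the Python side.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  // Launching here, rather than on the default stream, keeps the kernel
  // ordered with the surrounding PyTorch ops without extra synchronization.
  LaunchMyKernel(static_cast<float*>(x.data_ptr()), x.numel(), stream);  // hypothetical
}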
8 changes: 6 additions & 2 deletions python/csrc/batch_decode.cu
@@ -45,6 +45,7 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
     num_kv_heads = k_padded.size(1);
   }
 
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   auto o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -59,7 +60,7 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
         /*tmp=*/tmp,
         /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
         padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, RotaryMode(rotary_mode),
-        rope_scale, rope_theta);
+        rope_scale, rope_theta, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
                 status);
     return true;
@@ -88,6 +89,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  handler_.SetCUDAStream(torch_current_stream);
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
     DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
@@ -145,6 +148,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32);
   CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);
 
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse;
   if (return_lse) {
@@ -163,7 +167,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
         &handler_, static_cast<c_type*>(q.data_ptr()), /*q_rope_position=*/nullptr, paged_kv,
         static_cast<c_type*>(o.data_ptr()),
         /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
-        RotaryMode(rotary_mode), rope_scale, rope_theta, /*stream=*/nullptr);
+        RotaryMode(rotary_mode), rope_scale, rope_theta, /*stream=*/torch_current_stream);
     TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                 cudaGetErrorString(status));
   });
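For the handler-based ops (the paged-KV decode above, and the prefill wrappers below), the stream is wired in two places: Forward passes it directly to the kernel, while BeginForward stores it on the handler with SetCUDAStream so the planning and auxiliary work the handler launches also runs on the current stream. A rough sketch of what the setter amounts to (the real BatchDecodeHandler lives in FlashInfer's C++ headers; the member names here are assumed):

class HandlerSketch {
 public:
  // Bindings call this before BeginForward so that later internal
  // launches reuse the caller's current stream.
  void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }

 private:
  cudaStream_t stream_ = nullptr;  // legacy default stream until overridden
};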
10 changes: 8 additions & 2 deletions python/csrc/batch_prefill.cu
@@ -34,6 +34,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor work
   // TODO(Zihao): support dispatching to different index data types.
   CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  handler_.SetCUDAStream(torch_current_stream);
 
   cudaError_t status = handler_.BeginForward(
       static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
@@ -87,6 +89,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32);
   CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);
 
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -114,7 +117,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
             /*q_rope_position=*/nullptr, paged_kv, static_cast<c_type*>(o.data_ptr()),
             /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, rope_scale,
             rope_theta,
-            /*stream=*/nullptr);
+            /*stream=*/torch_current_stream);
         TORCH_CHECK(status == cudaSuccess,
                     "BatchPrefillWithPagedKVCache failed with error code ",
                     cudaGetErrorString(status));
@@ -152,6 +155,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor wor
   // TODO(Zihao): support dispatching to different index data types.
   CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  handler_.SetCUDAStream(torch_current_stream);
 
   cudaError_t status = handler_.BeginForward(
       static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
@@ -191,6 +196,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32);
 
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -213,7 +219,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
             static_cast<c_type*>(o.data_ptr()),
             /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
             num_kv_heads, rope_scale, rope_theta,
-            /*stream=*/nullptr);
+            /*stream=*/torch_current_stream);
         TORCH_CHECK(status == cudaSuccess,
                     "BatchPrefillWithRaggedKVCache failed with error ",
                     cudaGetErrorString(status));
9 changes: 6 additions & 3 deletions python/csrc/cascade.cu
@@ -39,6 +39,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   unsigned int seq_len = v_a.size(0);
   unsigned int num_heads = v_a.size(1);
   unsigned int head_dim = v_a.size(2);
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   auto v_merged = torch::empty_like(v_a, v_a.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
 
@@ -47,7 +48,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
     MergeState(static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
                static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
                static_cast<c_type*>(v_merged.data_ptr()),
-               static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim);
+               static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeState kernel launch failed: ", cudaGetErrorString(status));
     return true;
@@ -76,12 +77,13 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   unsigned int seq_len = v.size(0);
   unsigned int num_heads = v.size(1);
   unsigned int head_dim = v.size(2);
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
     cudaError_t status =
         MergeStateInPlace(static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
                           static_cast<c_type*>(v_other.data_ptr()),
-                          static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim);
+                          static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
     return true;
@@ -103,14 +105,15 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
   unsigned int num_heads = v.size(2);
   unsigned int head_dim = v.size(3);
   s = s.to(torch::kFloat32);
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s.options());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
     cudaError_t status = MergeStates(
         static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
         static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
-        num_index_sets, seq_len, num_heads, head_dim);
+        num_index_sets, seq_len, num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeStates kernel launch failed: ", cudaGetErrorString(status));
     return true;
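The cascade kernels are stateless, so no handler is involved: each launcher simply gained a trailing cudaStream_t parameter. A hedged sketch of that launcher shape (placeholder kernel body and launch geometry, not FlashInfer's actual MergeState):

template <typename T>
__global__ void MergeSketchKernel(const T* a, const T* b, T* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = a[i] + b[i];  // placeholder arithmetic
}

template <typename T>
cudaError_t MergeSketch(const T* a, const T* b, T* out, int n,
                        cudaStream_t stream = nullptr) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  // Launch on the caller-provided stream; a nullptr default means the
  // legacy default stream, which is what the old call sites got.
  MergeSketchKernel<T><<<blocks, threads, 0, stream>>>(a, b, out, n);
  return cudaGetLastError();  // surfaced by the binding's TORCH_CHECK
}

A defaulted stream parameter like this is also why the old call sites compiled cleanly while silently using the default stream: omitting the final argument fell back to nullptr.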
4 changes: 3 additions & 1 deletion python/csrc/page.cu
@@ -65,6 +65,8 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_EQ(append_value.size(1), num_heads);
   CHECK_EQ(append_key.size(2), head_dim);
 
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_data.scalar_type(), c_type, [&] {
     DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
       paged_kv_t<page_storage, KV_LAYOUT, c_type, int32_t> paged_kv(
@@ -73,7 +75,7 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
           static_cast<int32_t*>(kv_last_page_len.data_ptr()));
       cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
                                               static_cast<c_type*>(append_value.data_ptr()),
-                                              static_cast<int32_t*>(append_indptr.data_ptr()));
+                                              static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
       TORCH_CHECK(status == cudaSuccess,
                   "AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
       return true;
1 change: 1 addition & 0 deletions python/csrc/pytorch_extension_utils.h
@@ -15,6 +15,7 @@
  */
 #pragma once
 #include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
 
 #include "generated/dispatch.inc"
 
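This shared header is included by the .cu bindings above, so adding <c10/cuda/CUDAStream.h> here is what makes c10::cuda::getCurrentCUDAStream() visible at all of the new call sites.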
3 changes: 2 additions & 1 deletion python/csrc/single_decode.cu
@@ -43,14 +43,15 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
     num_kv_heads = k.size(0);
     kv_len = k.size(1);
   }
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   auto o = torch::empty_like(q, q.options());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
     cudaError_t status = SingleDecodeWithKVCache(
         static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
         static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
         static_cast<c_type*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
-        kv_layout, RotaryMode(rotary_mode), rope_scale, rope_theta, nullptr);
+        kv_layout, RotaryMode(rotary_mode), rope_scale, rope_theta, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
                                            std::string(cudaGetErrorString(status)));
     return true;
3 changes: 2 additions & 1 deletion python/csrc/single_prefill.cu
@@ -46,6 +46,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
     num_qo_heads = q.size(0);
   }
   CHECK(num_qo_heads % num_kv_heads == 0);
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   auto o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -66,7 +67,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
         static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
         static_cast<float*>(tmp.data_ptr()),
         /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
-        num_kv_heads, qo_len, kv_len, rope_scale, rope_theta, nullptr);
+        num_kv_heads, qo_len, kv_len, rope_scale, rope_theta, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "SinglePrefillWithKVCache kernel launch failed, error: " +
                     std::string(cudaGetErrorString(status)));
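The user-visible effect: work submitted inside a non-default-stream context now stays on that stream. A hedged illustration using the hypothetical my_op binding from the sketch above:

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

void run_on_side_stream() {
  c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side);  // `side` becomes the current stream
  auto q = torch::randn({32, 8, 128},
                        torch::dtype(torch::kHalf).device(torch::kCUDA));
  // Before this PR the kernel went to the default stream and could race with
  // the randn recorded on `side`; now both land on `side`, in order.
  my_op(q);  // hypothetical binding standing in for the FlashInfer ops
}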