Use Torch's current stream for ops (#111)
This PR makes the PyTorch ops use the current Torch stream for kernel
execution. This makes the ops compatible with Torch CUDA Graphs and lets
the user precisely control which stream to use from Python via the
canonical PyTorch API.

Note: I believe I have found all cases where the stream should be set,
but I might have missed something.
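
As a concrete illustration of the Python-side behavior this enables, here is a minimal sketch; the flashinfer import, the single_decode_with_kv_cache entry point, and the tensor shapes are assumptions for illustration, not prescribed usage:

import torch
import flashinfer  # assumed package name for this extension

# Arbitrary single-request decode shapes (NHD layout assumed):
# q: [num_qo_heads, head_dim], k/v: [kv_len, num_kv_heads, head_dim].
q = torch.randn(32, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda")

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # The C++ side now queries c10::cuda::getCurrentCUDAStream() instead of
    # launching on the legacy default stream, so this kernel runs on `s`.
    o = flashinfer.single_decode_with_kv_cache(q, k, v)
s.synchronize()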
Yard1 authored Feb 8, 2024
1 parent 3fc62cb commit 6c6c44a
Showing 7 changed files with 28 additions and 10 deletions.
8 changes: 6 additions & 2 deletions python/csrc/batch_decode.cu
@@ -45,6 +45,7 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
num_kv_heads = k_padded.size(1);
}

+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
@@ -59,7 +60,7 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, RotaryMode(rotary_mode),
-rope_scale, rope_theta);
+rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
@@ -88,6 +89,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
CHECK_EQ(indptr.scalar_type(), torch::kInt32);
CHECK_EQ(indptr.scalar_type(), torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+handler_.SetCUDAStream(torch_current_stream);

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
@@ -145,6 +148,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32);
CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);

+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse;
if (return_lse) {
@@ -163,7 +167,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_rope_position=*/nullptr, paged_kv,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
-RotaryMode(rotary_mode), rope_scale, rope_theta, /*stream=*/nullptr);
+RotaryMode(rotary_mode), rope_scale, rope_theta, /*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
});
10 changes: 8 additions & 2 deletions python/csrc/batch_prefill.cu
@@ -34,6 +34,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor work
// TODO(Zihao): support dispatching to different index data types.
CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+handler_.SetCUDAStream(torch_current_stream);

cudaError_t status = handler_.BeginForward(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
@@ -87,6 +89,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32);
CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);

+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
@@ -114,7 +117,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
/*q_rope_position=*/nullptr, paged_kv, static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, rope_scale,
rope_theta,
-/*stream=*/nullptr);
+/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCache failed with error code ",
cudaGetErrorString(status));
@@ -152,6 +155,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor wor
// TODO(Zihao): support dispatching to different index data types.
CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+handler_.SetCUDAStream(torch_current_stream);

cudaError_t status = handler_.BeginForward(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
@@ -191,6 +196,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32);
CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32);

+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
@@ -213,7 +219,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
num_kv_heads, rope_scale, rope_theta,
-/*stream=*/nullptr);
+/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCache failed with error ",
cudaGetErrorString(status));
9 changes: 6 additions & 3 deletions python/csrc/cascade.cu
@@ -39,6 +39,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
unsigned int seq_len = v_a.size(0);
unsigned int num_heads = v_a.size(1);
unsigned int head_dim = v_a.size(2);
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto v_merged = torch::empty_like(v_a, v_a.options());
auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());

@@ -47,7 +48,7 @@
MergeState(static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
static_cast<c_type*>(v_merged.data_ptr()),
-static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim);
+static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"MergeState kernel launch failed: ", cudaGetErrorString(status));
return true;
@@ -76,12 +77,13 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
unsigned int seq_len = v.size(0);
unsigned int num_heads = v.size(1);
unsigned int head_dim = v.size(2);
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
cudaError_t status =
MergeStateInPlace(static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
static_cast<c_type*>(v_other.data_ptr()),
-static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim);
+static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
return true;
@@ -103,14 +105,15 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
unsigned int num_heads = v.size(2);
unsigned int head_dim = v.size(3);
s = s.to(torch::kFloat32);
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
auto s_merged = torch::empty({seq_len, num_heads}, s.options());

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
cudaError_t status = MergeStates(
static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
-num_index_sets, seq_len, num_heads, head_dim);
+num_index_sets, seq_len, num_heads, head_dim, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"MergeStates kernel launch failed: ", cudaGetErrorString(status));
return true;
4 changes: 3 additions & 1 deletion python/csrc/page.cu
@@ -65,6 +65,8 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
CHECK_EQ(append_value.size(1), num_heads);
CHECK_EQ(append_key.size(2), head_dim);

+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+
bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
paged_kv_t<page_storage, KV_LAYOUT, c_type, int32_t> paged_kv(
@@ -73,7 +75,7 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
static_cast<int32_t*>(kv_last_page_len.data_ptr()));
cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
static_cast<c_type*>(append_value.data_ptr()),
-static_cast<int32_t*>(append_indptr.data_ptr()));
+static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
return true;
1 change: 1 addition & 0 deletions python/csrc/pytorch_extension_utils.h
@@ -15,6 +15,7 @@
*/
#pragma once
#include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>

#include "generated/dispatch.inc"

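The new include pulls in c10::cuda::getCurrentCUDAStream(), which reads the same thread-local, per-device setting that torch.cuda.current_stream() exposes in Python. A small sketch of that correspondence (plain PyTorch, nothing extension-specific):

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # Inside the context manager, the extension's
    # c10::cuda::getCurrentCUDAStream() call would also return `s`.
    assert torch.cuda.current_stream() == s
# Outside, the current stream falls back to the default stream.
assert torch.cuda.current_stream() == torch.cuda.default_stream()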
3 changes: 2 additions & 1 deletion python/csrc/single_decode.cu
@@ -43,14 +43,15 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
num_kv_heads = k.size(0);
kv_len = k.size(1);
}
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
-kv_layout, RotaryMode(rotary_mode), rope_scale, rope_theta, nullptr);
+kv_layout, RotaryMode(rotary_mode), rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
3 changes: 2 additions & 1 deletion python/csrc/single_prefill.cu
@@ -46,6 +46,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
num_qo_heads = q.size(0);
}
CHECK(num_qo_heads % num_kv_heads == 0);
+cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
Expand All @@ -66,7 +67,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
static_cast<float*>(tmp.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
-num_kv_heads, qo_len, kv_len, rope_scale, rope_theta, nullptr);
+num_kv_heads, qo_len, kv_len, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
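On the CUDA Graphs point from the commit message: torch.cuda.graph() captures work on a dedicated side stream, so kernels hard-wired to the legacy default stream (the previous /*stream=*/nullptr behavior) would not be recorded correctly, while kernels that launch on the current stream are captured and replayed. A minimal sketch with a plain PyTorch op standing in for the ops touched here:

import torch

x = torch.randn(64, 64, device="cuda")
w = torch.randn(64, 64, device="cuda")

# Warm up on a side stream first, as the PyTorch CUDA Graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x @ w
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # During capture, torch.cuda.current_stream() is the capture stream, so
    # any op that launches on the current stream is recorded into the graph.
    y = x @ w

g.replay()  # re-launches the captured kernels
torch.cuda.synchronize()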
