bugfix: check gpu id in PyTorch APIs and use input tensor's gpu default stream (#361)

This PR fixes #349 by using the default stream of the input tensors' device
instead of the default stream of the default device (which might differ from
the input tensors' device). This PR also adds a sanity check on input tensor
device ids (all input tensors must be on the same GPU).
yzh119 authored Jul 6, 2024
1 parent 3536198 commit 1b84fab
Showing 10 changed files with 129 additions and 44 deletions.
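
The heart of the fix is the stream lookup: c10::cuda::getCurrentCUDAStream() with no argument returns the current stream of whichever device is current (typically cuda:0), not necessarily the device holding the input tensors; passing an explicit device index queries that device's stream instead. A minimal sketch of the before/after pattern, distilled from the diffs below:

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

void launch_on_input_device(torch::Tensor q) {
  // Before: the stream of the *current* device (e.g. cuda:0), which is wrong
  // when q lives on another GPU such as cuda:1.
  cudaStream_t before = c10::cuda::getCurrentCUDAStream();
  // After: the stream of the device that actually holds the input tensor.
  auto device = q.device();
  cudaStream_t after = c10::cuda::getCurrentCUDAStream(device.index());
  (void)before;
  (void)after;
}
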
12 changes: 9 additions & 3 deletions python/csrc/batch_decode.cu
@@ -25,18 +25,19 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
     float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(zihao): not necessary to be CUDA tensor
   CHECK_CONTIGUOUS(indptr);
   CHECK_CONTIGUOUS(last_page_len);
-  CHECK_CONTIGUOUS(workspace_buffer);
   CHECK_DIM(1, indptr);
   CHECK_DIM(1, last_page_len);
   CHECK_DIM(1, workspace_buffer);
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(last_page_len.scalar_type(), torch::kInt32);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto device = workspace_buffer.device();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
   indptr = indptr.to(torch::kCPU);
   last_page_len = last_page_len.to(torch::kCPU);
@@ -116,6 +117,11 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(paged_kv_indptr);
   CHECK_INPUT(paged_kv_indices);
   CHECK_INPUT(paged_kv_last_page_len);
+  auto device = q.device();
+  CHECK_EQ(paged_kv_data.device(), device);
+  CHECK_EQ(paged_kv_indices.device(), device);
+  CHECK_EQ(paged_kv_indptr.device(), device);
+  CHECK_EQ(paged_kv_last_page_len.device(), device);
   CHECK_DIM(3, q);                       // (B, H_qo, D)
   CHECK_DIM(1, paged_kv_last_page_len);  // (B,)
   CHECK_DIM(1, paged_kv_indptr);         // (B+1,)
@@ -144,7 +150,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q);
   torch::Tensor lse;
   if (return_lse) {
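
For reference, the CHECK_* helpers used throughout these wrappers follow the standard PyTorch extension idiom; a sketch of plausible definitions (the repo keeps the real ones in a shared header, and the exact messages may differ):

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed")

Under definitions like these, promoting workspace_buffer from CHECK_CONTIGUOUS to CHECK_INPUT is what makes workspace_buffer.device() safe to feed into the stream lookup: the buffer is now guaranteed to be a CUDA tensor.
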
47 changes: 37 additions & 10 deletions python/csrc/batch_prefill.cu
@@ -24,17 +24,18 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
-  CHECK_CONTIGUOUS(workspace_buffer);
   CHECK_CONTIGUOUS(paged_kv_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, workspace_buffer);
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
 
+  auto device = workspace_buffer.device();
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
@@ -68,6 +69,12 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(paged_kv_indptr);
   CHECK_INPUT(paged_kv_indices);
   CHECK_INPUT(paged_kv_last_page_len);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, paged_kv_data.device());
+  CHECK_EQ(device, paged_kv_indptr.device());
+  CHECK_EQ(device, paged_kv_indices.device());
+  CHECK_EQ(device, paged_kv_last_page_len.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND
@@ -100,7 +107,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   paged_kv_indices = paged_kv_indices.to(torch::kInt32);
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -171,6 +178,14 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   CHECK_INPUT(paged_kv_last_page_len);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, paged_kv_data.device());
+  CHECK_EQ(device, paged_kv_indptr.device());
+  CHECK_EQ(device, paged_kv_indices.device());
+  CHECK_EQ(device, paged_kv_last_page_len.device());
+  CHECK_EQ(device, custom_mask.device());
+  CHECK_EQ(device, qk_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND
@@ -207,7 +222,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -267,17 +282,17 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, torch::Tensor empty_q_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
-  CHECK_CONTIGUOUS(workspace_buffer);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, workspace_buffer);
 
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kCPU).to(torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto device = workspace_buffer.device();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
@@ -309,6 +324,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(kv_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, k.device());
+  CHECK_EQ(device, v.device());
+  CHECK_EQ(device, kv_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   CHECK_DIM(3, k);          // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)
@@ -330,7 +350,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -396,6 +416,13 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   CHECK_INPUT(kv_indptr);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, k.device());
+  CHECK_EQ(device, v.device());
+  CHECK_EQ(device, kv_indptr.device());
+  CHECK_EQ(device, custom_mask.device());
+  CHECK_EQ(device, qk_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   CHECK_DIM(3, k);          // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)
@@ -421,7 +448,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   kv_indptr = kv_indptr.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
16 changes: 13 additions & 3 deletions python/csrc/cascade.cu
@@ -26,6 +26,10 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   CHECK_INPUT(s_a);
   CHECK_INPUT(v_b);
   CHECK_INPUT(s_b);
+  auto device = v_a.device();
+  CHECK_EQ(s_a.device(), device);
+  CHECK_EQ(v_b.device(), device);
+  CHECK_EQ(s_b.device(), device);
   CHECK_DIM(3, v_a);
   CHECK_DIM(2, s_a);
   CHECK_DIM(3, v_b);
@@ -39,7 +43,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   unsigned int seq_len = v_a.size(0);
   unsigned int num_heads = v_a.size(1);
   unsigned int head_dim = v_a.size(2);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto v_merged = torch::empty_like(v_a, v_a.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
 
@@ -64,6 +68,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   CHECK_INPUT(s);
   CHECK_INPUT(v_other);
   CHECK_INPUT(s_other);
+  auto device = v.device();
+  CHECK_EQ(s.device(), device);
+  CHECK_EQ(v_other.device(), device);
+  CHECK_EQ(s_other.device(), device);
   CHECK_DIM(3, v);
   CHECK_DIM(2, s);
   CHECK_DIM(3, v_other);
@@ -77,7 +85,7 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   unsigned int seq_len = v.size(0);
   unsigned int num_heads = v.size(1);
   unsigned int head_dim = v.size(2);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
     cudaError_t status = MergeStateInPlace(
@@ -95,6 +103,8 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
 std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
   CHECK_INPUT(v);
   CHECK_INPUT(s);
+  auto device = v.device();
+  CHECK_EQ(s.device(), device);
   CHECK_DIM(4, v);
   CHECK_DIM(3, s);
   CHECK_EQ(v.size(0), s.size(0));
@@ -105,7 +115,7 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
   unsigned int num_heads = v.size(2);
   unsigned int head_dim = v.size(3);
   s = s.to(torch::kFloat32);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s.options());
 
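
These equality checks rely on torch::Device (an alias of c10::Device) comparing both the device type and the index, so mixing GPUs trips the check instead of silently launching on the wrong stream. A small sketch of the behavior, assuming a machine with at least two GPUs:

#include <torch/torch.h>

torch::Tensor a = torch::zeros({8}, torch::TensorOptions().device(torch::kCUDA, 0));
torch::Tensor b = torch::zeros({8}, torch::TensorOptions().device(torch::kCUDA, 1));
TORCH_CHECK(a.device() == a.device());  // passes: same type, same index
TORCH_CHECK(a.device() == b.device());  // throws: cuda:0 vs cuda:1
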
14 changes: 9 additions & 5 deletions python/csrc/group_gemm.cu
@@ -31,9 +31,12 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr
                                                         unsigned int batch_size,
                                                         bool weight_column_major) {
   // TODO(Zihao): Add more checks here
-  CHECK_CUDA(seg_indptr);
-  CHECK_CUDA(x);
-  CHECK_CUDA(weight);
+  CHECK_INPUT(seg_indptr);
+  CHECK_INPUT(x);
+  CHECK_INPUT(weight);
+  auto device = x.device();
+  CHECK_EQ(seg_indptr.device(), device);
+  CHECK_EQ(weight.device(), device);
   CHECK_DIM(2, x);       // x: [sum(m_i), d_in]
   CHECK_DIM(3, weight);  // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights,
                          // d_in, d_out] otherwise
@@ -42,12 +45,13 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr
   int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1);
   CHECK_EQ(x.size(1), d_in);
   auto y = torch::zeros({cumulative_batch_size, d_out}, x.options());
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   seg_indptr = seg_indptr.to(torch::kInt64);
 
   bool weight_indices_defined = weight_indices.numel() > 0;
   if (weight_indices_defined) {
-    CHECK_CUDA(weight_indices);
+    CHECK_INPUT(weight_indices);
+    CHECK_EQ(weight_indices.device(), device);
     weight_indices = weight_indices.to(torch::kInt64);
   }
 
4 changes: 3 additions & 1 deletion python/csrc/norm.cu
@@ -23,13 +23,15 @@ using namespace flashinfer;
 torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) {
   CHECK_INPUT(x);
   CHECK_INPUT(w);
+  auto device = x.device();
+  CHECK_EQ(w.device(), device);
   CHECK_DIM(2, x);  // x: (batch_size, hidden_size)
   CHECK_DIM(1, w);  // w: (hidden_size)
   CHECK_EQ(x.size(1), w.size(0));
   unsigned int batch_size = x.size(0);
   unsigned int hidden_size = x.size(1);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto y = torch::empty_like(x);
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] {
     cudaError_t status = norm::RMSNorm(
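
The DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16 pattern seen above maps a runtime torch dtype onto a compile-time C++ type before instantiating the kernel template. An illustrative sketch of how such a macro is typically shaped (not the repo's exact definition):

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
  [&]() -> bool {                                                        \
    switch (pytorch_dtype) {                                             \
      case at::ScalarType::Half: {                                       \
        using c_type = __half; /* from <cuda_fp16.h> */                  \
        return __VA_ARGS__();                                            \
      }                                                                  \
      case at::ScalarType::BFloat16: {                                   \
        using c_type = __nv_bfloat16; /* from <cuda_bf16.h> */           \
        return __VA_ARGS__();                                            \
      }                                                                  \
      default:                                                           \
        return false;                                                    \
    }                                                                    \
  }()

The lambda passed at each call site (the [&] { ... } bodies in these diffs) reports whether the launch succeeded, which is why some call sites bind the result as bool success = DISPATCH_...(...).
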
9 changes: 8 additions & 1 deletion python/csrc/page.cu
@@ -45,6 +45,13 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_indices.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_last_page_len.scalar_type(), torch::kInt32);
+  auto device = append_indptr.device();
+  CHECK_EQ(append_key.device(), device);
+  CHECK_EQ(append_value.device(), device);
+  CHECK_EQ(kv_data.device(), device);
+  CHECK_EQ(kv_indices.device(), device);
+  CHECK_EQ(kv_indptr.device(), device);
+  CHECK_EQ(kv_last_page_len.device(), device);
 
   constexpr PageStorage page_storage = PageStorage::kIndices;
   QKVLayout kv_layout = QKVLayout(layout);
@@ -65,7 +72,7 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_EQ(append_value.size(1), num_heads);
   CHECK_EQ(append_key.size(2), head_dim);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_data.scalar_type(), c_type, [&] {
     DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
8 changes: 6 additions & 2 deletions python/csrc/quantization.cu
@@ -22,9 +22,10 @@ using namespace flashinfer;

 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) {
   CHECK_INPUT(x);
+  auto device = x.device();
   TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
   x = x.to(torch::kBool);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
   int64_t num_elements = x.numel();
   int64_t num_output_elements = (num_elements + 7) / 8;
@@ -46,6 +47,9 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
   CHECK_INPUT(x);
   CHECK_INPUT(input_indptr);
   CHECK_INPUT(output_indptr);
+  auto device = x.device();
+  CHECK_EQ(input_indptr.device(), device);
+  CHECK_EQ(output_indptr.device(), device);
   TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
   unsigned int batch_size = input_indptr.size(0) - 1;
   CHECK_EQ(output_indptr.size(0), batch_size + 1);
@@ -59,6 +63,6 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
       static_cast<int32_t*>(input_indptr.data_ptr()),
       static_cast<int32_t*>(output_indptr.data_ptr()), batch_size,
       bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle,
-      c10::cuda::getCurrentCUDAStream());
+      c10::cuda::getCurrentCUDAStream(device.index()));
   return y;
 }
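
Taken together, every entry point now derives its device from its own tensor arguments instead of the process-global current device. A condensed sketch of the resulting wrapper shape (example_op and its parameters are illustrative, not the repo's actual signatures):

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

torch::Tensor example_op(torch::Tensor x, torch::Tensor w) {
  CHECK_INPUT(x);                // CUDA + contiguous
  CHECK_INPUT(w);
  auto device = x.device();      // device of the inputs, not the default device
  CHECK_EQ(w.device(), device);  // all inputs must sit on the same GPU
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device.index());
  auto y = torch::empty_like(x);
  // ... launch kernels on `stream` ...
  return y;
}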