misc: add device guard for kernels #611

Merged (2 commits) on Nov 15, 2024
Changes from all commits
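Every commit in this PR applies the same one-line pattern: construct an at::cuda::OptionalCUDAGuard on the device of the input (or workspace) tensor before the wrapper queries the current CUDA stream and launches its kernel. Without the guard, a caller whose current CUDA device differs from the tensors' device (for example, tensors on cuda:1 while cuda:0 is current) can end up issuing the launch on the wrong GPU. Below is a minimal sketch of the pattern, not code from this PR; `scale_by_two` and its body are hypothetical, and only the OptionalCUDAGuard / getCurrentCUDAStream calls mirror the actual changes.

```cpp
// Hypothetical sketch of the device-guard pattern added throughout this PR.
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

torch::Tensor scale_by_two(torch::Tensor x) {
  auto device = x.device();

  // RAII guard: switches the current CUDA device to `device` for the rest of
  // this scope and restores the previous device on destruction. The Optional
  // variant is a no-op when given nullopt, so it also works for wrappers that
  // may receive an optional/undefined device.
  const at::cuda::OptionalCUDAGuard device_guard(device);

  // Query the stream of x's device; with the guard active, the subsequent
  // kernel launch is issued on that device as well, not on whatever device
  // happened to be current in the caller.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device.index());

  auto out = torch::empty_like(x);
  // ... real code would launch a CUDA kernel on `stream` here ...
  (void)stream;
  return out;
}
```

Note that bmm_fp8.cu calls at::cuda::getCurrentCUDAStream() with no device index, so there the guard also decides which device's stream is returned; the other files pass device.index() explicitly, and the guard chiefly ensures the launches themselves target the tensors' device.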
1 change: 1 addition & 0 deletions python/csrc/bmm_fp8.cu
@@ -54,6 +54,7 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
auto workspace_buffer = torch::empty(
{32 * 1024 * 1024}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
const at::cuda::OptionalCUDAGuard device_guard(A.device());
auto stream = at::cuda::getCurrentCUDAStream();

// PyTorch is row major by default. cuBLASLt is column major by default.
5 changes: 5 additions & 0 deletions python/csrc/cascade.cu
@@ -42,6 +42,8 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
unsigned int seq_len = v_a.size(0);
unsigned int num_heads = v_a.size(1);
unsigned int head_dim = v_a.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto v_merged = torch::empty_like(v_a, v_a.options());
auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
@@ -91,6 +93,8 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
unsigned int seq_len = v.size(0);
unsigned int num_heads = v.size(1);
unsigned int head_dim = v.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
@@ -121,6 +125,7 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
unsigned int num_heads = v.size(2);
unsigned int head_dim = v.size(3);
s = s.to(torch::kFloat32);
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
auto s_merged = torch::empty({seq_len, num_heads}, s.options());
2 changes: 2 additions & 0 deletions python/csrc/group_gemm.cu
@@ -25,6 +25,8 @@ void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_proble
torch::Tensor empty_x_data, bool weight_column_major) {
unsigned int batch_size = x_ptr.size(0);
auto device = workspace_buffer.device();

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] {
2 changes: 2 additions & 0 deletions python/csrc/group_gemm_sm90.cu
@@ -27,6 +27,8 @@ void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer,
bool weight_column_major) {
unsigned int batch_size = x_ptr.size(0);
auto device = float_workspace_buffer.device();

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] {
4 changes: 4 additions & 0 deletions python/csrc/norm.cu
@@ -32,6 +32,7 @@ void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight,
CHECK_EQ(output.size(0), batch_size);
CHECK_EQ(output.size(1), hidden_size);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
@@ -61,6 +62,7 @@ void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Ten
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()),
@@ -86,6 +88,7 @@ void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& w
CHECK_EQ(output.size(0), batch_size);
CHECK_EQ(output.size(1), hidden_size);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
@@ -115,6 +118,7 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaFusedAddRMSNorm(
3 changes: 2 additions & 1 deletion python/csrc/page.cu
@@ -92,7 +92,8 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
CHECK_EQ(append_key.size(2), head_dim);
CHECK_EQ(append_value.size(1), num_heads);
CHECK_EQ(append_value.size(2), head_dim);


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

auto kv_scalar_dtype = paged_k_cache.scalar_type();
1 change: 1 addition & 0 deletions python/csrc/pytorch_extension_utils.h
@@ -15,6 +15,7 @@
*/
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
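Adding <c10/cuda/CUDAGuard.h> here makes at::cuda::OptionalCUDAGuard available to every kernel source in one place; pytorch_extension_utils.h is presumably the shared header those .cu files already include for the DISPATCH_PYTORCH_DTYPE_TO_CTYPE_* macros, so no per-file include is needed.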
3 changes: 3 additions & 0 deletions python/csrc/quantization.cu
@@ -24,6 +24,8 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) {
auto device = x.device();
TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
x = x.to(torch::kBool);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

int64_t num_elements = x.numel();
@@ -57,6 +59,7 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
int64_t output_nnz = output_indptr[batch_size].item<int64_t>();
auto y = torch::empty({output_nnz}, x.options().dtype(torch::kUInt8));

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaError_t status = quantization::SegmentPackBits(
static_cast<bool*>(x.data_ptr()), static_cast<uint8_t*>(y.data_ptr()),
static_cast<int32_t*>(input_indptr.data_ptr()),
7 changes: 6 additions & 1 deletion python/csrc/rope.cu
@@ -50,7 +50,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
size_t k_rope_stride_h = k_rope.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotary(
@@ -93,6 +94,7 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryPosIds(
@@ -145,6 +147,7 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
@@ -195,6 +198,7 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31Rotary(
@@ -240,6 +244,7 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31RotaryPosIds(
11 changes: 10 additions & 1 deletion python/csrc/sampling.cu
@@ -33,6 +33,7 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));

@@ -71,6 +72,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
uniform_samples = uniform_samples.to(torch::kFloat32);
top_p_arr = top_p_arr.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -112,6 +114,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
uniform_samples = uniform_samples.to(torch::kFloat32);
top_k_arr = top_k_arr.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -153,6 +156,7 @@ std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -203,6 +207,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -236,7 +241,8 @@ torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
}
top_p_arr = top_p_arr.to(torch::kFloat32);
probs = probs.to(torch::kFloat32);


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto renorm_probs =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
@@ -268,6 +274,7 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
top_k_arr = top_k_arr.to(torch::kInt32);
probs = probs.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto renorm_probs =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
@@ -300,6 +307,7 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
top_k_arr = top_k_arr.to(torch::kInt32);
logits = logits.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto mask_logits =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
@@ -348,6 +356,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
uniform_samples = uniform_samples.to(torch::kFloat32);
target_probs = target_probs.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
torch::dtype(torch::kInt32).device(device));
2 changes: 2 additions & 0 deletions python/csrc_aot/batch_decode.cu
@@ -42,6 +42,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

@@ -112,6 +113,7 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun(
}
uint32_t head_dim = q.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
if (maybe_lse) {
1 change: 1 addition & 0 deletions python/csrc_aot/batch_prefill.cu
@@ -50,6 +50,7 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
1 change: 1 addition & 0 deletions python/csrc_aot/single_decode.cu
@@ -60,6 +60,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
kv_len = k.size(1);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);

1 change: 1 addition & 0 deletions python/csrc_aot/single_prefill.cu
@@ -56,6 +56,7 @@ torch::Tensor single_prefill_with_kv_cache(
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
4 changes: 3 additions & 1 deletion python/flashinfer/jit/batch_decode_mla_templ.py
@@ -40,6 +40,7 @@
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
indptr = indptr.to(torch::kCPU);

@@ -83,8 +84,9 @@
auto device = q_nope.device();
int64_t batch_size = q_nope.size(0);
int64_t num_qo_heads = q_nope.size(1);
int64_t page_size = paged_ckv_cache.size(1);;
int64_t page_size = paged_ckv_cache.size(1);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q_nope);
torch::Tensor lse;
4 changes: 3 additions & 1 deletion python/flashinfer/jit/batch_decode_templ.py
@@ -41,6 +41,7 @@
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

@@ -93,7 +94,8 @@
page_size = paged_k_cache.size(1);
num_kv_heads = paged_k_cache.size(2);
}


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
if (maybe_lse) {
3 changes: 3 additions & 0 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -45,6 +45,7 @@
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
@@ -92,6 +93,7 @@
}

auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
@@ -187,6 +189,7 @@
num_kv_heads = paged_k_cache.size(2);
}

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
2 changes: 2 additions & 0 deletions python/flashinfer/jit/single_decode_templ.py
@@ -106,6 +106,7 @@
num_kv_heads = k.size(0);
kv_len = k.size(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);

@@ -157,6 +158,7 @@
num_kv_heads = k.size(0);
kv_len = k.size(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);

2 changes: 2 additions & 0 deletions python/flashinfer/jit/single_prefill_templ.py
@@ -102,6 +102,7 @@
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
@@ -177,6 +178,7 @@
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {