diff --git a/.gitignore b/.gitignore index fa13a77b7..197301443 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,6 @@ src/generated/ python/csrc/generated/ python/flashinfer/_build_meta.py python/flashinfer/jit/aot_config.py -python/csrc_aot/generated/ # Package files python/flashinfer/data/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5633f788a..01a3086ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: - id: clang-format types_or: [c++, c, cuda] exclude: | - (?x)^(3rdparty/.* src/generated/.* python/flashinfer/jit/aot_config.py python/csrc_aot/generated/.*)$ + (?x)^(3rdparty/.* src/generated/.* python/flashinfer/jit/aot_config.py)$ - repo: https://github.com/cheshirekow/cmake-format-precommit rev: v0.6.13 diff --git a/include/flashinfer/allocator.h b/include/flashinfer/allocator.h index 7a1e40375..7bd1a854e 100644 --- a/include/flashinfer/allocator.h +++ b/include/flashinfer/allocator.h @@ -18,7 +18,8 @@ #include #include -#include + +#include "exception.h" namespace flashinfer { @@ -44,7 +45,7 @@ struct AlignedAllocator { std::ostringstream oss; oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment " << alignment << " in AlignedAllocator"; - throw std::runtime_error(oss.str()); + FLASHINFER_ERROR(oss.str()); } return nullptr; } diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 6860f8f9a..c8543b3f9 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -687,7 +687,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT if (nblks.x == 0 || nblks.y == 0) { std::ostringstream err_msg; err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; - throw std::runtime_error(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } dim3 nthrs = dim3(bdx, bdy, bdz); float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 750e150b9..4d8c62ad5 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1375,7 +1375,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " "to qo_len, got kv_len" << kv_len << " and qo_len " << qo_len; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } const uint32_t group_size = num_qo_heads / num_kv_heads; @@ -1442,7 +1442,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } else { constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16; @@ -2165,7 +2165,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; - throw
std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } else { // TODO(Zihao): fix the following computation uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) + @@ -2267,7 +2267,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } else { // TODO(Zihao): fix the following computation uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) + diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 64e82413a..6ac3d06ad 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -333,7 +333,7 @@ struct DecodePlanInfo { if (vec.size() != 10) { std::ostringstream err_msg; err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, but got " << vec.size(); - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } padded_batch_size = vec[0]; v_offset = vec[1]; @@ -440,14 +440,14 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin std::ostringstream err_msg; err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" << qo_indptr_h[i] << " should be non-negative"; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); if (kv_len_arr[i] < 0) { std::ostringstream err_msg; err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]" << kv_indptr_h[i] << " should be non-negative"; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } sum_packed_qo_len += packed_qo_len_arr[i]; } @@ -570,7 +570,7 @@ struct PrefillPlanInfo { if (vec.size() != 14) { std::ostringstream err_msg; err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size(); - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } padded_batch_size = vec[0]; total_num_rows = vec[1]; @@ -601,7 +601,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " << num_kv_heads; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } // step 0: get the number of SMs diff --git a/include/flashinfer/exception.h b/include/flashinfer/exception.h new file mode 100644 index 000000000..9d4f9d783 --- /dev/null +++ b/include/flashinfer/exception.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_EXCEPTION_H_ +#define FLASHINFER_EXCEPTION_H_ + +#include +#include + +namespace flashinfer { + +class Error : public std::exception { + private: + std::string message_; + + public: + Error(const std::string& func, const std::string& file, int line, const std::string& message) { + std::ostringstream oss; + oss << "Error in function '" << func << "' " + << "at " << file << ":" << line << ": " << message; + message_ = oss.str(); + } + + virtual const char* what() const noexcept override { return message_.c_str(); } +}; + +#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message) + +#define FLASHINFER_CHECK(condition, message) \ + if (!(condition)) { \ + FLASHINFER_ERROR(message); \ + } + +} // namespace flashinfer + +#endif // FLASHINFER_EXCEPTION_H_ diff --git a/include/flashinfer/gemm/bmm_fp8.cuh b/include/flashinfer/gemm/bmm_fp8.cuh index 0be406194..853d803ce 100644 --- a/include/flashinfer/gemm/bmm_fp8.cuh +++ b/include/flashinfer/gemm/bmm_fp8.cuh @@ -19,15 +19,17 @@ #include #include -#include +#include +#include #include -#define FLASHINFER_CUBLAS_CHECK(EXPR) \ - { \ - cublasStatus_t e = (EXPR); \ - if (e != CUBLAS_STATUS_SUCCESS) { \ - throw std::runtime_error("CUBLAS Error: " + std::string(cublasGetStatusString(e))); \ - } \ +#include "../exception.h" + +#define FLASHINFER_CUBLAS_CHECK(EXPR) \ + { \ + cublasStatus_t e = (EXPR); \ + FLASHINFER_CHECK(e == CUBLAS_STATUS_SUCCESS, \ + "CUBLAS Error: " + std::string(cublasGetStatusString(e))); \ } #ifndef NDEBUG @@ -131,7 +133,7 @@ cudaDataType_t get_cuda_data_type() { } else if constexpr (std::is_same_v) { return CUDA_R_16F; } else { - throw std::runtime_error("Unsupported type"); + FLASHINFER_ERROR("Unsupported type"); } } @@ -155,7 +157,7 @@ cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_ cudaDataType_t b_type = get_cuda_data_type(); cudaDataType_t d_type = get_cuda_data_type
(); if (std::is_same_v && std::is_same_v) { - throw std::runtime_error("Unsupported combination: both A and B are e5m2"); + FLASHINFER_ERROR("Unsupported combination: both A and B are e5m2"); } auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true); diff --git a/include/flashinfer/gemm/group_gemm.cuh b/include/flashinfer/gemm/group_gemm.cuh index c60b3d89e..3c142eea5 100644 --- a/include/flashinfer/gemm/group_gemm.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -79,13 +79,13 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe if (status != cutlass::Status::kSuccess) { std::ostringstream err_msg; err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status); - throw std::runtime_error(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } status = gemm.run(stream); if (status != cutlass::Status::kSuccess) { std::ostringstream err_msg; err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status); - throw std::runtime_error(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } }); diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index 9560cb9ee..7c1644141 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -73,7 +73,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si sizeof(DTypeIn) == 1) { std::ostringstream err_msg; err_msg << "Row-major layout is not supported for fp8 data type"; - throw std::runtime_error(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } else { using LayoutA = cutlass::layout::RowMajor; constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; diff --git a/include/flashinfer/math.cuh b/include/flashinfer/math.cuh index 6ec7dbc12..27c6351e8 100644 --- a/include/flashinfer/math.cuh +++ b/include/flashinfer/math.cuh @@ -19,6 +19,8 @@ #include #include +#include + namespace flashinfer { namespace math { diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index edce23cc6..a4841a319 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -23,10 +23,10 @@ #include #include -#include -#include #include +#include "exception.h" + #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) @@ -57,7 +57,7 @@ #define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \ if (allow_fp16_qk_reduction) { \ - throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \ + FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ } else { \ constexpr bool ALLOW_FP16_QK_REDUCTION = false; \ __VA_ARGS__ \ @@ -73,7 +73,7 @@ } else { \ std::ostringstream err_msg; \ err_msg << "Unsupported num_frags_q: " << num_frags_q; \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } #define DISPATCH_NUM_FRAGS_KV(max_frags_kv, NUM_FRAGS_KV, ...) \ @@ -92,7 +92,7 @@ } else { \ std::ostringstream err_msg; \ err_msg << "Unsupported max_frags_kv: " << max_frags_kv; \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } #define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) 
\ @@ -115,7 +115,7 @@ default: { \ std::ostringstream err_msg; \ err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } \ } @@ -138,7 +138,7 @@ } else { \ std::ostringstream err_msg; \ err_msg << "Unsupported group_size: " << group_size; \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } #define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ @@ -161,7 +161,7 @@ default: { \ std::ostringstream err_msg; \ err_msg << "Unsupported mask_mode: " << int(mask_mode); \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } \ } @@ -190,7 +190,7 @@ default: { \ std::ostringstream err_msg; \ err_msg << "Unsupported head_dim: " << head_dim; \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } \ } @@ -214,7 +214,7 @@ default: { \ std::ostringstream err_msg; \ err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } \ } @@ -248,7 +248,7 @@ default: { \ std::ostringstream err_msg; \ err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - throw std::invalid_argument(err_msg.str()); \ + FLASHINFER_ERROR(err_msg.str()); \ } \ } diff --git a/python/aot_MANIFEST.in b/python/aot_MANIFEST.in index e19887692..d28673e14 100644 --- a/python/aot_MANIFEST.in +++ b/python/aot_MANIFEST.in @@ -2,7 +2,6 @@ prune */__pycache__ prune csrc -prune csrc_aot exclude aot_setup.py exclude setup.py diff --git a/python/aot_setup.py b/python/aot_setup.py index 670c2cb8b..3f24eec1f 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -64,7 +64,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None: def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: - path = root / "python" / "csrc_aot" / "generated" + path = root / "python" / "csrc" / "generated" path.mkdir(parents=True, exist_ok=True) head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") @@ -423,12 +423,12 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: "csrc/quantization.cu", "csrc/rope.cu", "csrc/sampling.cu", - "csrc_aot/activation.cu", - "csrc_aot/batch_decode.cu", - "csrc_aot/batch_prefill.cu", - "csrc_aot/flashinfer_ops.cu", - "csrc_aot/single_decode.cu", - "csrc_aot/single_prefill.cu", + "csrc/activation.cu", + "csrc/batch_decode.cu", + "csrc/batch_prefill.cu", + "csrc/single_decode.cu", + "csrc/single_prefill.cu", + "csrc/flashinfer_ops.cu", ] + files_decode + files_prefill, diff --git a/python/csrc_aot/activation.cu b/python/csrc/activation.cu similarity index 78% rename from python/csrc_aot/activation.cu rename to python/csrc/activation.cu index bb2126b18..dba033597 100644 --- a/python/csrc_aot/activation.cu +++ b/python/csrc/activation.cu @@ -13,12 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include -#include - #include -#include "pytorch_extension_utils.h" +#include "aot_extension_utils.h" using namespace flashinfer; @@ -35,13 +32,12 @@ __device__ __forceinline__ float gelu_tanh(const float& val) { return val * cdf; } -void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); @@ -52,13 +48,12 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { }); } -void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) { +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); @@ -69,13 +64,12 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) { }); } -void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) { +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); diff --git a/python/csrc/aot_extension_utils.h b/python/csrc/aot_extension_utils.h new file mode 100644 index 000000000..76db0168d --- /dev/null +++ b/python/csrc/aot_extension_utils.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "generated/dispatch.inc" +#include "pytorch_extension_utils.h" + +#define DISPATCH_head_dim(expr, const_expr, ...) \ + _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) + +#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("positional encoding mode", expr, \ + _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_allow_fp16_qk_reduction(expr, const_expr, ...) 
\ + _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ + _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) + +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ + [&]() -> bool { \ + if (use_logits_soft_cap) { \ + constexpr bool USE_LOGITS_SOFT_CAP = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool USE_LOGITS_SOFT_CAP = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \ + [&]() -> bool { \ + if (kv_dtype == q_dtype) { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_dtype, c_type_q, [&] { \ + using c_type_kv = c_type_q; \ + return __VA_ARGS__(); \ + }); \ + } else { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_dtype, c_type_q, [&] { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_dtype, c_type_kv, \ + [&] { return __VA_ARGS__(); }); \ + }); \ + } \ + }() diff --git a/python/csrc_aot/batch_decode.cu b/python/csrc/batch_decode.cu similarity index 82% rename from python/csrc_aot/batch_decode.cu rename to python/csrc/batch_decode.cu index 3d4747851..19cea2c8b 100644 --- a/python/csrc_aot/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -13,15 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include - #include #include #include #include -#include "pytorch_extension_utils.h" +#include "aot_extension_utils.h" namespace flashinfer { @@ -32,31 +29,28 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par } // namespace flashinfer +using namespace flashinfer; + std::vector BatchDecodeWithPagedKVCachePlan( - bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, - torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) { + bool use_logits_soft_cap, unsigned int head_dim, at::Tensor empty_q_data, + at::Tensor empty_kv_data, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - const at::cuda::CUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU"); DecodePlanInfo plan_info; using IdType = int32_t; - // check indptr has idtype int32 - TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32"); constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; auto q_scalar_type = empty_q_data.scalar_type(); auto kv_scalar_type = empty_kv_data.scalar_type(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, 
[&] { using DTypeQ = q_type; using DTypeKV = kv_type; @@ -77,7 +71,7 @@ std::vector BatchDecodeWithPagedKVCachePlan( static_cast(page_locked_int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, plan_info, static_cast(indptr.data_ptr()), batch_size, num_qo_heads, page_size, enable_cuda_graph, - /*stream=*/torch_current_stream, work_estimation_func); + /*stream=*/stream, work_estimation_func); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); @@ -90,13 +84,13 @@ std::vector BatchDecodeWithPagedKVCachePlan( return plan_info.ToVector(); } -torch::Tensor BatchDecodeWithPagedKVCacheRun( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, +void BatchDecodeWithPagedKVCacheRun( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, std::optional alibi_slopes, at::Tensor o, unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse) { + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream) { DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(kv_layout_code); @@ -114,14 +108,10 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun( } uint32_t head_dim = q.size(2); - const at::cuda::CUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -147,6 +137,8 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun( TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); kv_cache_strides = k_strides.data(); + cudaStream_t stream = reinterpret_cast(cuda_stream); + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; using DTypeKV = kv_type; @@ -194,13 +186,11 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun( cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheDispatched( - params, tmp_v, tmp_s, /*stream=*/torch_current_stream); + params, tmp_v, tmp_s, /*stream=*/stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); return true; }); }); }); - - return o; } diff --git a/python/csrc_aot/batch_prefill.cu b/python/csrc/batch_prefill.cu similarity index 82% rename from python/csrc_aot/batch_prefill.cu rename to python/csrc/batch_prefill.cu index f0d13489c..970bf0c9e 100644 --- a/python/csrc_aot/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -13,16 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include -#include - #include #include #include #include #include -#include "pytorch_extension_utils.h" +#include "aot_extension_utils.h" namespace flashinfer { @@ -40,32 +37,29 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P } // namespace flashinfer +using namespace flashinfer; + std::vector BatchPrefillWithKVCachePlan( - unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) { + unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, + unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, + unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - const at::cuda::CUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU"); - TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU"); - PrefillPlanInfo plan_info; using IdType = int32_t; + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = PrefillPlan( float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), kv_indptr.data_ptr(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, - enable_cuda_graph, /*sizeof_dtype_o=*/2, torch_current_stream); + enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); @@ -73,14 +67,13 @@ std::vector BatchPrefillWithKVCachePlan( return plan_info.ToVector(); } -torch::Tensor BatchPrefillWithRaggedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse) { +void BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); @@ -97,14 +90,10 @@ torch::Tensor 
BatchPrefillWithRaggedKVCacheRun( kv_stride_n = k.stride(1); } - auto device = float_workspace_buffer.device(); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } void* float_buffer_ptr = float_workspace_buffer.data_ptr(); @@ -118,6 +107,7 @@ torch::Tensor BatchPrefillWithRaggedKVCacheRun( auto q_scalar_type = q.scalar_type(); auto kv_scalar_type = k.scalar_type(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; using DTypeKV = kv_type; @@ -178,8 +168,8 @@ torch::Tensor BatchPrefillWithRaggedKVCacheRun( DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched< CTA_TILE_Q, HEAD_DIM, POS_ENCODING_MODE, - /*use_fp16_qk_reduction=*/false, MASK_MODE, RaggedAttentionVariant>( - params, tmp_v, tmp_s, torch_current_stream); + /*use_fp16_qk_reduction=*/false, MASK_MODE, RaggedAttentionVariant>(params, tmp_v, + tmp_s, stream); }); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", @@ -189,19 +179,17 @@ torch::Tensor BatchPrefillWithRaggedKVCacheRun( }); }); }); - - return o; } -torch::Tensor BatchPrefillWithPagedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, - torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, - unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse) { +void BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); @@ -218,13 +206,10 @@ torch::Tensor BatchPrefillWithPagedKVCacheRun( num_kv_heads = paged_k_cache.size(2); } - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); @@ -248,6 +233,8 @@ torch::Tensor BatchPrefillWithPagedKVCacheRun( TORCH_CHECK(k_strides 
== v_strides, "k/v strides must be identical"); kv_cache_strides = k_strides.data(); + cudaStream_t stream = reinterpret_cast(cuda_stream); + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; using DTypeKV = kv_type; @@ -311,8 +298,8 @@ torch::Tensor BatchPrefillWithPagedKVCacheRun( DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { status = flashinfer::BatchPrefillWithPagedKVCacheDispatched< CTA_TILE_Q, HEAD_DIM, POS_ENCODING_MODE, - /*use_fp16_qk_reduction=*/false, MASK_MODE, PagedAttentionVariant>( - params, tmp_v, tmp_s, torch_current_stream); + /*use_fp16_qk_reduction=*/false, MASK_MODE, PagedAttentionVariant>(params, tmp_v, + tmp_s, stream); }); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", @@ -322,6 +309,4 @@ torch::Tensor BatchPrefillWithPagedKVCacheRun( }); }); }); - - return o; } diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index 36dfa1517..568ac7a42 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -14,15 +14,14 @@ * limitations under the License. */ -#include -#include +#include #include #include "pytorch_extension_utils.h" -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale) { +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream) { TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor"); TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor"); TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor"); @@ -33,30 +32,14 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes"); TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2), "Result tensor has incorrect shape"); - TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn || A.scalar_type() == torch::kFloat8_e5m2, - "A must be Float8_e4m3fn or Float8_e5m2"); - TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn || B.scalar_type() == torch::kFloat8_e5m2, - "B must be Float8_e4m3fn or Float8_e5m2"); - TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf, - "D must be BFloat16 or Half"); - - TORCH_CHECK(A_scale.scalar_type() == torch::kFloat32 && B_scale.scalar_type() == torch::kFloat32, - "A_scale and B_scale must be Float32"); auto batch_size = A.size(0); auto m = A.size(1); auto k = A.size(2); auto n = B.size(2); - // Per the cublas documentation, the recommended workspace buffer size for hopper is 32MB. - // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace - // create an empty buffer of 32MB, with data type uint8 and on the same device as A - auto workspace_buffer = torch::empty( - {32 * 1024 * 1024}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); - auto lt_handle = reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); - const at::cuda::OptionalCUDAGuard device_guard(A.device()); - auto stream = at::cuda::getCurrentCUDAStream(); - + auto lt_handle = reinterpret_cast(cublas_handle); + auto stream = reinterpret_cast(cuda_stream); // PyTorch is row major by default. cuBLASLt is column major by default. // We need row major D as expected. 
// A ^ T * B = D, so D ^ T = B ^ T * A diff --git a/python/csrc/cascade.cu b/python/csrc/cascade.cu index 97ac7cbc1..8ba2ad6b9 100644 --- a/python/csrc/cascade.cu +++ b/python/csrc/cascade.cu @@ -19,8 +19,8 @@ using namespace flashinfer; -std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, - torch::Tensor s_b) { +void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, + at::Tensor v_merged, at::Tensor s_merged, int64_t cuda_stream) { CHECK_INPUT(v_a); CHECK_INPUT(s_a); CHECK_INPUT(v_b); @@ -37,34 +37,28 @@ std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, tor CHECK_SHAPE(s_a, s_b); CHECK_EQ(v_a.size(0), s_a.size(0)); CHECK_EQ(v_a.size(1), s_b.size(1)); - s_a = s_a.to(torch::kFloat32); - s_b = s_b.to(torch::kFloat32); unsigned int seq_len = v_a.size(0); unsigned int num_heads = v_a.size(1); unsigned int head_dim = v_a.size(2); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto v_merged = torch::empty_like(v_a, v_a.options()); - auto s_merged = torch::empty({seq_len, num_heads}, s_a.options()); + cudaStream_t stream = reinterpret_cast(cuda_stream); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] { - cudaError_t status = MergeState( - static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), - static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), - static_cast(v_merged.data_ptr()), static_cast(s_merged.data_ptr()), - seq_len, num_heads, head_dim, torch_current_stream); + cudaError_t status = + MergeState(static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), + static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), + static_cast(v_merged.data_ptr()), + static_cast(s_merged.data_ptr()), seq_len, num_heads, head_dim, stream); TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status)); return true; }); TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type"); - return {v_merged, s_merged}; } -void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, - torch::Tensor s_other, std::optional mask) { +void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other, + std::optional mask, int64_t cuda_stream) { CHECK_INPUT(v); CHECK_INPUT(s); CHECK_INPUT(v_other); @@ -81,8 +75,6 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe CHECK_SHAPE(s, s_other); CHECK_EQ(v.size(0), s.size(0)); CHECK_EQ(v.size(1), s.size(1)); - CHECK_EQ(s.scalar_type(), torch::kFloat32); - CHECK_EQ(s_other.scalar_type(), torch::kFloat32); uint8_t* mask_ptr = nullptr; if (mask.has_value()) { CHECK_DIM(1, mask.value()); @@ -94,14 +86,12 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe unsigned int num_heads = v.size(1); unsigned int head_dim = v.size(2); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - + cudaStream_t stream = reinterpret_cast(cuda_stream); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] { cudaError_t status = MergeStateInPlace( static_cast(v.data_ptr()), static_cast(s.data_ptr()), static_cast(v_other.data_ptr()), static_cast(s_other.data_ptr()), seq_len, - num_heads, head_dim, mask_ptr, torch_current_stream); + num_heads, head_dim, mask_ptr, stream); TORCH_CHECK(status == cudaSuccess, 
"MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status)); return true; @@ -110,7 +100,8 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe TORCH_CHECK(success, "MergeStateInPlace kernel launch failed: unsupported data type"); } -std::vector merge_states(torch::Tensor v, torch::Tensor s) { +void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged, + int64_t cuda_stream) { CHECK_INPUT(v); CHECK_INPUT(s); auto device = v.device(); @@ -124,22 +115,17 @@ std::vector merge_states(torch::Tensor v, torch::Tensor s) { unsigned int num_index_sets = v.size(1); unsigned int num_heads = v.size(2); unsigned int head_dim = v.size(3); - s = s.to(torch::kFloat32); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options()); - auto s_merged = torch::empty({seq_len, num_heads}, s.options()); + cudaStream_t stream = reinterpret_cast(cuda_stream); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] { cudaError_t status = MergeStates( static_cast(v.data_ptr()), static_cast(s.data_ptr()), static_cast(v_merged.data_ptr()), static_cast(s_merged.data_ptr()), - num_index_sets, seq_len, num_heads, head_dim, torch_current_stream); + num_index_sets, seq_len, num_heads, head_dim, stream); TORCH_CHECK(status == cudaSuccess, "MergeStates kernel launch failed: ", cudaGetErrorString(status)); return true; }); TORCH_CHECK(success, "MergeStates kernel launch failed: unsupported data type"); - return {v_merged, s_merged}; } diff --git a/python/csrc/dispatch_type_code.h b/python/csrc/dispatch_type_code.h new file mode 100644 index 000000000..4f717b95a --- /dev/null +++ b/python/csrc/dispatch_type_code.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include + +#include + +using namespace flashinfer; + +enum class TypeCode { + kFloat64 = 0, + kFloat32 = 1, + kFloat16 = 2, + kBFloat16 = 3, + kFloat8_e4m3fn = 4, + kFloat8_e5m2 = 5, + kInt64 = 100, + kUInt64 = 101, + kInt32 = 102, + kUInt32 = 103, + kInt16 = 104, + kUInt16 = 105, + kInt8 = 106, + kUInt8 = 107, +}; + +#ifdef FLASHINFER_ENABLE_BF16 +#define DISPATCH_TYPE_CODE_TO_CTYPE_FP16(type_code, c_type, ...) \ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat16: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kBFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_TYPE_CODE_TO_CTYPE_FP16(type_code, c_type, ...) 
\ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat16: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#endif + +#ifdef FLASHINFER_ENABLE_FP8 +#define DISPATCH_TYPE_CODE_TO_CTYPE_FP8(type_code, c_type, ...) \ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kFloat8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_TYPE_CODE_TO_CTYPE_FP8(type_code, c_type, ...) \ + [&]() -> bool { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + }() +#endif + +#if defined(FLASHINFER_ENABLE_BF16) && defined(FLASHINFER_ENABLE_FP8) +#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) \ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat16: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kBFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kFloat8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kFloat8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#elif defined(FLASHINFER_ENABLE_BF16) +#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) \ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat16: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kBFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#elif defined(FLASHINFER_ENABLE_FP8) +#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) \ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat16: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kFloat8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case TypeCode::kFloat8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) 
\ + [&]() -> bool { \ + switch (TypeCode(type_code)) { \ + case TypeCode::kFloat16: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() +#endif diff --git a/python/csrc/dispatch_utils.h b/python/csrc/dispatch_utils.h new file mode 100644 index 000000000..b34253e23 --- /dev/null +++ b/python/csrc/dispatch_utils.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include + +#include +#include + +#include "dispatch_type_code.h" +#include "generated/dispatch.inc" + +using namespace flashinfer; + +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + FLASHINFER_ERROR(oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_head_dim(expr, const_expr, ...) \ + _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) + +#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("positional encoding mode", expr, \ + _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_allow_fp16_qk_reduction(expr, const_expr, ...) \ + _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ + _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) + +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ + [&]() -> bool { \ + if (use_logits_soft_cap) { \ + constexpr bool USE_LOGITS_SOFT_CAP = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool USE_LOGITS_SOFT_CAP = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/python/csrc/flashinfer_cascade_ops.cu b/python/csrc/flashinfer_cascade_ops.cu index 2b3270381..4527022d7 100644 --- a/python/csrc/flashinfer_cascade_ops.cu +++ b/python/csrc/flashinfer_cascade_ops.cu @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include +#include "pytorch_extension_utils.h" -std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, - torch::Tensor s_b); +void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, + at::Tensor v_merged, at::Tensor s_merged, int64_t cuda_stream); -void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, - torch::Tensor s_other, std::optional mask = std::nullopt); +void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other, + std::optional mask, int64_t cuda_stream); -std::vector merge_states(torch::Tensor v, torch::Tensor s); +void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("merge_state", &merge_state, "Merge two self-attention states"); diff --git a/python/csrc/flashinfer_gemm_ops.cu b/python/csrc/flashinfer_gemm_ops.cu index 2e31ec0ce..b13a2bcd6 100644 --- a/python/csrc/flashinfer_gemm_ops.cu +++ b/python/csrc/flashinfer_gemm_ops.cu @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include "pytorch_extension_utils.h" -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale); +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); -void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_problems, - torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, - torch::Tensor x_ld, torch::Tensor w_ld, torch::Tensor y_ld, - torch::Tensor empty_x_data, bool weight_column_major); +void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr, + at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld, + at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM"); diff --git a/python/csrc/flashinfer_gemm_sm90_ops.cu b/python/csrc/flashinfer_gemm_sm90_ops.cu index 10d909847..b6802e424 100644 --- a/python/csrc/flashinfer_gemm_sm90_ops.cu +++ b/python/csrc/flashinfer_gemm_sm90_ops.cu @@ -13,14 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include +#include "pytorch_extension_utils.h" -void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor all_problems, - torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, - torch::Tensor x_stride, torch::Tensor weight_stride, - torch::Tensor y_stride, torch::Tensor empty_x_data, - bool weight_column_major); +void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, + at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, + at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, diff --git a/python/csrc/flashinfer_norm_ops.cu b/python/csrc/flashinfer_norm_ops.cu index 8c3f33850..52a103508 100644 --- a/python/csrc/flashinfer_norm_ops.cu +++ b/python/csrc/flashinfer_norm_ops.cu @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include "pytorch_extension_utils.h" -void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); +void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, + int64_t cuda_stream); -void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps); +void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream); -void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); +void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, + int64_t cuda_stream); -void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps); +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, + double eps, int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu new file mode 100644 index 000000000..b9ca57474 --- /dev/null +++ b/python/csrc/flashinfer_ops.cu @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "aot_extension_utils.h" + +//========== activation ========== + +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +//========== cascade ========== + +void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, + at::Tensor v_merged, at::Tensor s_merged, int64_t cuda_stream); + +void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other, + std::optional mask, int64_t cuda_stream); + +void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged, + int64_t cuda_stream); + +//========== decode ========== + +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, + std::optional alibi_slopes, at::Tensor o, + unsigned int layout, int window_left, float logits_soft_cap, + float sm_scale, float rope_scale, float rope_theta, + int64_t cuda_stream); + +std::vector BatchDecodeWithPagedKVCachePlan( + bool use_logits_soft_cap, unsigned int head_dim, at::Tensor empty_q_data, + at::Tensor empty_kv_data, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); + +void BatchDecodeWithPagedKVCacheRun( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, std::optional alibi_slopes, at::Tensor o, + unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream); + +//========== gemm ========== + +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); + +void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr, + at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld, + at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream); + +//========== norm ========== + +void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, + int64_t cuda_stream); + +void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream); + +void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, + int64_t cuda_stream); + +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, + double eps, int64_t cuda_stream); + +//========== page ========== + +void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, + at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, + unsigned int layout, int64_t cuda_stream); + +//========== prefill ========== + +void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, std::optional maybe_packed_custom_mask, + at::Tensor tmp, std::optional
maybe_alibi_slopes, + at::Tensor o, unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); + +std::vector BatchPrefillWithKVCachePlan( + unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, + unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, + unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream); + +void BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream); + +void BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream); + +//========== quantization ========== + +void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y, int64_t cuda_stream); + +void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, + const std::string& bitorder, at::Tensor y, int64_t cuda_stream); + +//========== rope ========== + +void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr, + at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, int64_t cuda_stream); + +void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, int64_t cuda_stream); + +void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, + float rope_scale, float rope_theta, int64_t cuda_stream); + +void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, + int64_t cuda_stream); + +void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, + at::Tensor k_rope, at::Tensor cos_cache, at::Tensor sin_cache, + at::Tensor pos_ids, bool interleave, int64_t cuda_stream); + +//========== sampling ========== + +void sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + bool deterministic, int64_t cuda_stream); + +void 
top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_p_arr, + double top_p_val, bool deterministic, int64_t cuda_stream); + +void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic, int64_t cuda_stream); + +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_min_p_arr, + double min_p_val, bool deterministic, int64_t cuda_stream); + +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, + at::Tensor samples, at::Tensor success, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, + bool deterministic, int64_t cuda_stream); + +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_p_arr, double top_p_val, + int64_t cuda_stream); + +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, unsigned int top_k_val, + int64_t cuda_stream); + +void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, + std::optional maybe_top_k_arr, unsigned int top_k_val, + int64_t cuda_stream); + +void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids, + at::Tensor uniform_samples, at::Tensor target_probs, + at::Tensor output_token_ids, at::Tensor output_accepted_token_num, + at::Tensor output_emitted_token_num, bool deterministic, + int64_t cuda_stream); + +//========== pybind11 ========== + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // activation + m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); + m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); + + // cascade + m.def("merge_state", &merge_state, "Merge two self-attention states"); + m.def("merge_state_in_place", &merge_state_in_place, + "Merge another self-attention state in-place."); + m.def("merge_states", &merge_states, "Merge multiple self-attention states"); + + // decode + m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, + "Single-request decode with KV-Cache operator"); + m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); + m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); + + // gemm + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); + m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); + + // norm + m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); + m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); + m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, + "Gemma Fused add root mean square normalization"); + + // page + m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + + // prefill + m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, + "Single-request prefill attention with KV-Cache operator"); + m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); + m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); + m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); + + // quantization + 
m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + + // rope + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); + m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, + "Apply Llama 3.1 style RoPE with positional ids"); + + // sampling + m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); + m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, + "Top-k sampling from probabilities"); + m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, + "Min-p sampling from probabilities"); + m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, + "Top-p sampling from probabilities"); + m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, + "Top-k and top-p sampling from probabilities"); + m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); + m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); + m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); + m.def("chain_speculative_sampling", &chain_speculative_sampling, + "Speculative sampling from sequence of probabilities"); +} diff --git a/python/csrc/flashinfer_page_ops.cu b/python/csrc/flashinfer_page_ops.cu index aacaa4859..d78d4ac00 100644 --- a/python/csrc/flashinfer_page_ops.cu +++ b/python/csrc/flashinfer_page_ops.cu @@ -13,13 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include "pytorch_extension_utils.h" -void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor batch_indices, torch::Tensor positions, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - torch::Tensor kv_indices, torch::Tensor kv_indptr, - torch::Tensor kv_last_page_len, unsigned int layout); +void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, + at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, + unsigned int layout, int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); diff --git a/python/csrc/flashinfer_quantization_ops.cu b/python/csrc/flashinfer_quantization_ops.cu index 7f2886091..d23867bfb 100644 --- a/python/csrc/flashinfer_quantization_ops.cu +++ b/python/csrc/flashinfer_quantization_ops.cu @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include +#include "pytorch_extension_utils.h" -torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); +void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y, int64_t cuda_stream); -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, - torch::Tensor output_indptr, const std::string& bitorder); +void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, + const std::string& bitorder, at::Tensor y, int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("packbits", &packbits, "GPU packbits operator"); diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu index 369997c4b..41fb284a2 100644 --- a/python/csrc/flashinfer_rope_ops.cu +++ b/python/csrc/flashinfer_rope_ops.cu @@ -13,34 +13,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include - #include -void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta); +#include "pytorch_extension_utils.h" + +void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr, + at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, int64_t cuda_stream); -void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, - unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); +void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, int64_t cuda_stream); -void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta); +void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, + float rope_scale, float rope_theta, int64_t cuda_stream); -void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, - unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); +void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, + int64_t cuda_stream); -void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor cos_cache, - torch::Tensor sin_cache, torch::Tensor pos_ids, - bool interleave); +void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, + at::Tensor k_rope, at::Tensor cos_cache, at::Tensor sin_cache, + at::Tensor pos_ids, bool interleave, int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
m.def("apply_rope", &apply_rope, "Apply RoPE"); diff --git a/python/csrc/flashinfer_sampling_ops.cu b/python/csrc/flashinfer_sampling_ops.cu index 1437a67b0..3ae155710 100644 --- a/python/csrc/flashinfer_sampling_ops.cu +++ b/python/csrc/flashinfer_sampling_ops.cu @@ -13,45 +13,46 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include "pytorch_extension_utils.h" -torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - bool deterministic); +void sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + bool deterministic, int64_t cuda_stream); -std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, - double top_p_val, bool deterministic); +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_p_arr, + double top_p_val, bool deterministic, int64_t cuda_stream); -std::vector top_k_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic); +void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic, int64_t cuda_stream); -std::vector min_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, - double min_p_val, bool deterministic); +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_min_p_arr, + double min_p_val, bool deterministic, int64_t cuda_stream); -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic); +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, + at::Tensor samples, at::Tensor success, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, + bool deterministic, int64_t cuda_stream); -torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val); +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_p_arr, double top_p_val, + int64_t cuda_stream); -torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val); +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, unsigned int top_k_val, + int64_t cuda_stream); -torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val); +void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, + std::optional maybe_top_k_arr, unsigned int top_k_val, + int64_t cuda_stream); -torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, - torch::Tensor uniform_samples, torch::Tensor target_probs, - torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, - bool deterministic); +void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids, + at::Tensor uniform_samples, at::Tensor target_probs, + at::Tensor output_token_ids, 
at::Tensor output_accepted_token_num, + at::Tensor output_emitted_token_num, bool deterministic, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index ce593bef5..78779fe5d 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -17,25 +17,22 @@ #include "pytorch_extension_utils.h" +using namespace flashinfer; using namespace flashinfer::group_gemm; -void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_problems, - torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, - torch::Tensor x_ld, torch::Tensor w_ld, torch::Tensor y_ld, - torch::Tensor empty_x_data, bool weight_column_major) { +void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr, + at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld, + at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream) { unsigned int batch_size = x_ptr.size(0); - auto device = workspace_buffer.device(); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::value; auto status = CutlassSegmentGEMMRun( workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), - x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, - torch_current_stream); + x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, stream); TORCH_CHECK(status == cudaSuccess, "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/group_gemm_sm90.cu b/python/csrc/group_gemm_sm90.cu index 9c4eaf9a9..3710cf2fd 100644 --- a/python/csrc/group_gemm_sm90.cu +++ b/python/csrc/group_gemm_sm90.cu @@ -17,20 +17,18 @@ #include "pytorch_extension_utils.h" +using namespace flashinfer; using namespace flashinfer::group_gemm; -void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor all_problems, - torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, - torch::Tensor x_stride, torch::Tensor weight_stride, - torch::Tensor y_stride, torch::Tensor empty_x_data, - bool weight_column_major) { +void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, + at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, + at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream) { unsigned int batch_size = x_ptr.size(0); auto device = float_workspace_buffer.device(); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::value; auto status = CutlassSegmentGEMMSM90Run( @@ -39,7 +37,7 @@ void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, int_workspace_buffer.data_ptr(), 
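// CutlassSegmentGEMM no longer queries c10::cuda::getCurrentCUDAStream itself; the stream
// lookup is assumed to move to the caller. A sketch of producing the int64_t handle these
// bindings expect, assuming the ATen CUDA runtime is available; `current_stream_as_int` is
// a hypothetical helper, not part of the diff.
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>

int64_t current_stream_as_int(c10::DeviceIndex device_index) {
  // Fetch PyTorch's current stream for the device and flatten it to an integer.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_index).stream();
  return reinterpret_cast<int64_t>(stream);
}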
int_workspace_buffer.element_size() * int_workspace_buffer.size(0), all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), x_stride.data_ptr(), - weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, torch_current_stream); + weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, stream); TORCH_CHECK(status == cudaSuccess, "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/norm.cu b/python/csrc/norm.cu index be5754202..e1405dce1 100644 --- a/python/csrc/norm.cu +++ b/python/csrc/norm.cu @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include "pytorch_extension_utils.h" using namespace flashinfer; -void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps) { +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, + int64_t cuda_stream) { CHECK_INPUT(input); CHECK_INPUT(weight); auto device = input.device(); @@ -32,21 +34,19 @@ void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, CHECK_EQ(output.size(0), batch_size); CHECK_EQ(output.size(1), hidden_size); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = norm::RMSNorm(static_cast(input.data_ptr()), - static_cast(weight.data_ptr()), - static_cast(output.data_ptr()), batch_size, - hidden_size, eps, torch_current_stream); + cudaError_t status = norm::RMSNorm( + static_cast(input.data_ptr()), static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, hidden_size, eps, stream); TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } -void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps) { +void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream) { CHECK_INPUT(input); CHECK_INPUT(residual); CHECK_INPUT(weight); @@ -62,20 +62,19 @@ void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Ten unsigned int batch_size = input.size(0); unsigned int hidden_size = input.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = norm::FusedAddRMSNorm(static_cast(input.data_ptr()), - static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, - hidden_size, eps, torch_current_stream); + cudaError_t status = norm::FusedAddRMSNorm( + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, eps, stream); TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } -void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps) { +void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, + int64_t cuda_stream) { 
CHECK_INPUT(input); CHECK_INPUT(weight); auto device = input.device(); @@ -88,21 +87,19 @@ void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& w CHECK_EQ(output.size(0), batch_size); CHECK_EQ(output.size(1), hidden_size); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = norm::GemmaRMSNorm(static_cast(input.data_ptr()), - static_cast(weight.data_ptr()), - static_cast(output.data_ptr()), batch_size, - hidden_size, eps, torch_current_stream); + cudaError_t status = norm::GemmaRMSNorm( + static_cast(input.data_ptr()), static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, hidden_size, eps, stream); TORCH_CHECK(status == cudaSuccess, "GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } -void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps) { +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, + double eps, int64_t cuda_stream) { CHECK_INPUT(input); CHECK_INPUT(residual); CHECK_INPUT(weight); @@ -118,13 +115,11 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc unsigned int batch_size = input.size(0); unsigned int hidden_size = input.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { cudaError_t status = norm::GemmaFusedAddRMSNorm( static_cast(input.data_ptr()), static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, hidden_size, eps, - torch_current_stream); + static_cast(weight.data_ptr()), batch_size, hidden_size, eps, stream); TORCH_CHECK(status == cudaSuccess, "GemmaFusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 2a2bf136e..644a7dc62 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -19,11 +19,10 @@ using namespace flashinfer; -void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor batch_indices, torch::Tensor positions, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - torch::Tensor kv_indices, torch::Tensor kv_indptr, - torch::Tensor kv_last_page_len, unsigned int layout) { +void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, + at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, + unsigned int layout, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(append_key); CHECK_LAST_DIM_CONTIGUOUS(append_value); CHECK_INPUT(batch_indices); @@ -48,11 +47,6 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, CHECK_EQ(kv_indptr.size(0), batch_size + 1); CHECK_EQ(batch_indices.size(0), nnz); CHECK_EQ(positions.size(0), nnz); - CHECK_EQ(batch_indices.scalar_type(), torch::kInt32); - CHECK_EQ(positions.scalar_type(), torch::kInt32); - CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(kv_indices.scalar_type(), torch::kInt32); - 
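// The norm and GEMM bindings above funnel each call through DISPATCH_PYTORCH_DTYPE_TO_CTYPE
// (or its _FP16 variant), which maps a runtime at::ScalarType onto a compile-time element
// type. A simplified, hand-rolled sketch of that idea, assuming only fp16/bf16 inputs;
// `dispatch_fp16_bf16` is a hypothetical helper, not the macro's actual definition.
#include <ATen/ATen.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

template <typename Func>
bool dispatch_fp16_bf16(at::ScalarType dtype, Func&& f) {
  switch (dtype) {
    case at::ScalarType::Half:
      return f(__half{});  // instantiate the fp16 path
    case at::ScalarType::BFloat16:
      return f(__nv_bfloat16{});  // instantiate the bf16 path
    default:
      return false;  // unsupported dtype
  }
}
// Usage sketch: dispatch_fp16_bf16(input.scalar_type(), [&](auto tag) {
//   using c_type = decltype(tag);  // call a launcher templated on c_type here (hypothetical)
//   return true;
// });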
CHECK_EQ(kv_last_page_len.scalar_type(), torch::kInt32); auto device = append_key.device(); CHECK_EQ(append_key.device(), device); CHECK_EQ(append_value.device(), device); @@ -92,12 +86,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, CHECK_EQ(append_key.size(2), head_dim); CHECK_EQ(append_value.size(1), num_heads); CHECK_EQ(append_value.size(2), head_dim); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto kv_scalar_dtype = paged_k_cache.scalar_type(); + cudaStream_t stream = reinterpret_cast(cuda_stream); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] { paged_kv_t paged_kv( num_heads, page_size, head_dim, batch_size, kv_layout, @@ -105,12 +97,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, static_cast(paged_v_cache.data_ptr()), kv_cache_strides, static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), static_cast(kv_last_page_len.data_ptr())); - cudaError_t status = AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), - static_cast(append_value.data_ptr()), - static_cast(batch_indices.data_ptr()), - static_cast(positions.data_ptr()), nnz, - append_k_stride_n, append_k_stride_h, append_v_stride_n, - append_v_stride_h, torch_current_stream); + cudaError_t status = + AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), + static_cast(append_value.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), nnz, append_k_stride_n, + append_k_stride_h, append_v_stride_n, append_v_stride_h, stream); TORCH_CHECK(status == cudaSuccess, "AppendPagedKVCache failed with error: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 7f67bdcdc..825d6431f 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -14,14 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include +// NOTE(Zihao): only include minimal headers to accelerate compilation #include #include #include -#include - -using namespace flashinfer; +#include #ifdef FLASHINFER_ENABLE_BF16 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ @@ -192,7 +189,7 @@ using namespace flashinfer; return __VA_ARGS__(); \ } -inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, const char* a_name, +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) { TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim()); @@ -230,7 +227,7 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { #define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. 
", a, " vs ", b) -inline bool is_float8_tensor(const torch::Tensor& tensor) { +inline bool is_float8_tensor(const at::Tensor& tensor) { return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; } diff --git a/python/csrc/quantization.cu b/python/csrc/quantization.cu index 3471f792a..04a36b126 100644 --- a/python/csrc/quantization.cu +++ b/python/csrc/quantization.cu @@ -19,32 +19,23 @@ using namespace flashinfer; -torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) { +void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y, int64_t cuda_stream) { CHECK_INPUT(x); auto device = x.device(); TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'"); - x = x.to(torch::kBool); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); int64_t num_elements = x.numel(); - int64_t num_output_elements = (num_elements + 7) / 8; - - auto y = torch::empty({num_output_elements}, x.options().dtype(torch::kUInt8)); - + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = quantization::PackBits( static_cast(x.data_ptr()), static_cast(y.data_ptr()), num_elements, - bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle, - torch_current_stream); + bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle, stream); TORCH_CHECK(status == cudaSuccess, "PackBits failed with error code " + std::string(cudaGetErrorString(status))); - return y; } -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, - torch::Tensor output_indptr, const std::string& bitorder) { +void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, + const std::string& bitorder, at::Tensor y, int64_t cuda_stream) { CHECK_INPUT(x); CHECK_INPUT(input_indptr); CHECK_INPUT(output_indptr); @@ -54,17 +45,11 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'"); unsigned int batch_size = input_indptr.size(0) - 1; CHECK_EQ(output_indptr.size(0), batch_size + 1); - input_indptr = input_indptr.to(torch::kInt32); - output_indptr = output_indptr.to(torch::kInt32); - int64_t output_nnz = output_indptr[batch_size].item(); - auto y = torch::empty({output_nnz}, x.options().dtype(torch::kUInt8)); - const at::cuda::OptionalCUDAGuard device_guard(device); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = quantization::SegmentPackBits( static_cast(x.data_ptr()), static_cast(y.data_ptr()), static_cast(input_indptr.data_ptr()), static_cast(output_indptr.data_ptr()), batch_size, - bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle, - c10::cuda::getCurrentCUDAStream(device.index())); - return y; + bitorder == "big" ? 
quantization::BitOrder::kBig : quantization::BitOrder::kLittle, stream); } diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index dc21f9758..a1018dbca 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -19,9 +19,9 @@ using namespace flashinfer; -void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta) { +void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr, + at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(indptr); @@ -48,11 +48,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T size_t q_rope_stride_h = q_rope.stride(1); size_t k_rope_stride_n = k_rope.stride(0); size_t k_rope_stride_h = k_rope.stride(1); - indptr = indptr.to(torch::kInt32); - offsets = offsets.to(torch::kInt32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyRotary( static_cast(q.data_ptr()), static_cast(k.data_ptr()), @@ -60,16 +57,16 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, - interleave, rope_scale, rope_theta, torch_current_stream); + interleave, rope_scale, rope_theta, stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } -void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta) { +void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, + float rope_scale, float rope_theta, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(pos_ids); @@ -92,27 +89,24 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, size_t q_rope_stride_h = q_rope.stride(1); size_t k_rope_stride_n = k_rope.stride(0); size_t k_rope_stride_h = k_rope.stride(1); - pos_ids = pos_ids.to(torch::kInt32); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyRotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, - k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); + k_rope_stride_n, 
k_rope_stride_h, interleave, rope_scale, rope_theta, stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } -void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor cos_cache, - torch::Tensor sin_cache, torch::Tensor pos_ids, - bool interleave) { +void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, + at::Tensor k_rope, at::Tensor cos_cache, at::Tensor sin_cache, + at::Tensor pos_ids, bool interleave, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(cos_cache); @@ -131,8 +125,6 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T CHECK_EQ(q.size(2), k.size(2)); unsigned int rotary_dim = cos_cache.size(1); CHECK_EQ(sin_cache.size(1), rotary_dim); - CHECK_EQ(cos_cache.dtype(), torch::kFloat32); - CHECK_EQ(sin_cache.dtype(), torch::kFloat32); unsigned int num_qo_heads = q.size(1); unsigned int num_kv_heads = k.size(1); unsigned int head_dim = q.size(2); @@ -145,10 +137,9 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T size_t q_rope_stride_h = q_rope.stride(1); size_t k_rope_stride_n = k_rope.stride(0); size_t k_rope_stride_h = k_rope.stride(1); - pos_ids = pos_ids.to(torch::kInt32); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); + cudaStream_t torch_current_stream(nullptr); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( static_cast(q.data_ptr()), static_cast(k.data_ptr()), @@ -156,7 +147,7 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T static_cast(cos_cache.data_ptr()), static_cast(sin_cache.data_ptr()), static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, - k_rope_stride_n, k_rope_stride_h, interleave, torch_current_stream); + k_rope_stride_n, k_rope_stride_h, interleave, stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status))); @@ -164,11 +155,10 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T }); } -void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, - unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length) { +void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, int64_t cuda_stream) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -195,11 +185,8 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, size_t q_rope_stride_h = q_rope.stride(1); size_t k_rope_stride_n = k_rope.stride(0); size_t k_rope_stride_h = k_rope.stride(1); - indptr = 
indptr.to(torch::kInt32); - offsets = offsets.to(torch::kInt32); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyLlama31Rotary( static_cast(q.data_ptr()), static_cast(k.data_ptr()), @@ -208,18 +195,18 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, - torch_current_stream); + stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } -void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, - unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length) { +void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, + int64_t cuda_stream) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(pos_ids); @@ -242,10 +229,9 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor size_t q_rope_stride_h = q_rope.stride(1); size_t k_rope_stride_n = k_rope.stride(0); size_t k_rope_stride_h = k_rope.stride(1); - pos_ids = pos_ids.to(torch::kInt32); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + cudaStream_t stream = reinterpret_cast(cuda_stream); + cudaStream_t torch_current_stream(nullptr); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyLlama31RotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), @@ -253,7 +239,7 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, - high_freq_factor, old_context_length, torch_current_stream); + high_freq_factor, old_context_length, stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); return true; diff --git a/python/csrc/runtime_utils.h b/python/csrc/runtime_utils.h new file mode 100644 index 000000000..cc4c3f707 --- /dev/null +++ b/python/csrc/runtime_utils.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
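// The rope.cu hunks above also drop the in-binding indptr/offsets/pos_ids conversions to
// torch::kInt32, so correctly typed index tensors are presumably expected from the caller.
// A sketch of a caller-side guard under that assumption; `check_int32_indices` is a
// hypothetical helper, not part of the diff.
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

void check_int32_indices(const at::Tensor& indptr, const at::Tensor& pos_ids) {
  TORCH_CHECK(indptr.scalar_type() == at::kInt, "indptr must be int32");
  TORCH_CHECK(pos_ids.scalar_type() == at::kInt, "pos_ids must be int32");
}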
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#define FLASHINFER_DLL __attribute__((visibility("default"))) diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu index d7068864f..0f0021a16 100644 --- a/python/csrc/sampling.cu +++ b/python/csrc/sampling.cu @@ -19,8 +19,8 @@ using namespace flashinfer; -torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - bool deterministic) { +void sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + bool deterministic, int64_t cuda_stream) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -30,26 +30,18 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam CHECK_EQ(probs.size(0), uniform_samples.size(0)); unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - - cudaError_t status = sampling::SamplingFromProb(static_cast(probs.data_ptr()), - static_cast(uniform_samples.data_ptr()), - static_cast(samples.data_ptr()), batch_size, - vocab_size, deterministic, torch_current_stream); + cudaStream_t stream = reinterpret_cast(cuda_stream); + cudaError_t status = sampling::SamplingFromProb( + static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), + static_cast(samples.data_ptr()), batch_size, vocab_size, deterministic, stream); TORCH_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); - return samples; } -std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, - double top_p_val, bool deterministic) { +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_p_arr, + double top_p_val, bool deterministic, int64_t cuda_stream) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -61,37 +53,20 @@ std::vector top_p_sampling_from_probs(torch::Tensor probs, unsigned int vocab_size = probs.size(1); unsigned int max_top_p_rounds = uniform_samples.size(0); bool has_top_p_arr = maybe_top_p_arr.has_value(); - auto top_p_arr = maybe_top_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_top_p_arr) { - CHECK_INPUT(top_p_arr); - CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) - CHECK_EQ(top_p_arr.size(0), batch_size); - CHECK_EQ(top_p_arr.device(), device); - } - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - top_p_arr = top_p_arr.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = torch::empty({batch_size}, 
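// runtime_utils.h defines FLASHINFER_DLL as a default-visibility attribute. A sketch of the
// intended use, assuming the extension is otherwise built with hidden symbol visibility so
// that only annotated entry points remain exported; `demo_entry_point` is a hypothetical
// symbol and the GCC/Clang attribute syntax is assumed.
#define FLASHINFER_DLL __attribute__((visibility("default")))

FLASHINFER_DLL void demo_entry_point();  // stays visible to the dynamic loader under -fvisibility=hidden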
torch::dtype(torch::kInt32).device(device)); - auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TopPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), static_cast(samples.data_ptr()), static_cast(success.data_ptr()), - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, batch_size, top_p_val, - vocab_size, max_top_p_rounds, deterministic, torch_current_stream); + has_top_p_arr ? static_cast(maybe_top_p_arr->data_ptr()) : nullptr, batch_size, + top_p_val, vocab_size, max_top_p_rounds, deterministic, stream); TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); - - return {samples, success}; } -std::vector top_k_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic) { +void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic, int64_t cuda_stream) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -103,37 +78,20 @@ std::vector top_k_sampling_from_probs(torch::Tensor probs, unsigned int vocab_size = probs.size(1); unsigned int max_top_k_rounds = uniform_samples.size(0); bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - top_k_arr = top_k_arr.to(torch::kInt32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TopKSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), static_cast(samples.data_ptr()), static_cast(success.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, batch_size, top_k_val, - vocab_size, max_top_k_rounds, deterministic, torch_current_stream); + has_top_k_arr ? 
static_cast(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, + top_k_val, vocab_size, max_top_k_rounds, deterministic, stream); TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); - - return {samples, success}; } -std::vector min_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, - double min_p_val, bool deterministic) { +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_min_p_arr, + double min_p_val, bool deterministic, int64_t cuda_stream) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -145,37 +103,22 @@ std::vector min_p_sampling_from_probs(torch::Tensor probs, unsigned int max_rounds = uniform_samples.size(0); CHECK_EQ(uniform_samples.size(1), batch_size); bool has_min_p_arr = maybe_min_p_arr.has_value(); - auto min_p_arr = maybe_min_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_min_p_arr) { - CHECK_INPUT(min_p_arr); - CHECK_DIM(1, min_p_arr); // min_p_arr: (batch_size,) - CHECK_EQ(min_p_arr.size(0), batch_size); - CHECK_EQ(min_p_arr.device(), device); - } - min_p_arr = min_p_arr.to(torch::kFloat32); - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::MinPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), - has_min_p_arr ? static_cast(min_p_arr.data_ptr()) : nullptr, + has_min_p_arr ? 
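// Each sampling binding forwards its std::optional<at::Tensor> parameter as a raw device
// pointer, or nullptr when the optional is empty, as in the ternaries above. A small sketch
// of that pattern; `optional_float_ptr` is a hypothetical helper, not part of the diff.
#include <ATen/ATen.h>
#include <optional>

float* optional_float_ptr(const std::optional<at::Tensor>& maybe_arr) {
  // data_ptr() may only be called when a tensor is actually present.
  return maybe_arr.has_value() ? static_cast<float*>(maybe_arr->data_ptr()) : nullptr;
}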
static_cast(maybe_min_p_arr->data_ptr()) : nullptr, static_cast(samples.data_ptr()), static_cast(success.data_ptr()), batch_size, - min_p_val, vocab_size, max_rounds, deterministic, torch_current_stream); + min_p_val, vocab_size, max_rounds, deterministic, stream); TORCH_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " + std::string(cudaGetErrorString(status))); - - return {samples, success}; } -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic) { +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, + at::Tensor samples, at::Tensor success, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, + bool deterministic, int64_t cuda_stream) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -187,146 +130,83 @@ std::vector top_k_top_p_sampling_from_probs( unsigned int max_rounds = uniform_samples.size(0); CHECK_EQ(uniform_samples.size(1), batch_size); bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - top_k_arr = top_k_arr.to(torch::kInt32); bool has_top_p_arr = maybe_top_p_arr.has_value(); - auto top_p_arr = maybe_top_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_top_p_arr) { - CHECK_INPUT(top_p_arr); - CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) - CHECK_EQ(top_p_arr.size(0), batch_size); - CHECK_EQ(top_p_arr.device(), device); - } - top_p_arr = top_p_arr.to(torch::kFloat32); - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TopKTopPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, + has_top_k_arr ? static_cast(maybe_top_k_arr->data_ptr()) : nullptr, + has_top_p_arr ? 
static_cast(maybe_top_p_arr->data_ptr()) : nullptr, static_cast(samples.data_ptr()), static_cast(success.data_ptr()), batch_size, - top_k_val, top_p_val, vocab_size, max_rounds, deterministic, torch_current_stream); + top_k_val, top_p_val, vocab_size, max_rounds, deterministic, stream); TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); - - return {samples, success}; } -torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val) { +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_p_arr, double top_p_val, + int64_t cuda_stream) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); bool has_top_p_arr = maybe_top_p_arr.has_value(); - auto top_p_arr = maybe_top_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_top_p_arr) { - CHECK_INPUT(top_p_arr); - CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) - CHECK_EQ(top_p_arr.size(0), batch_size); - CHECK_EQ(top_p_arr.device(), device); - } - top_p_arr = top_p_arr.to(torch::kFloat32); - probs = probs.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto renorm_probs = - torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TopPRenormProb( static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, batch_size, top_p_val, - vocab_size, torch_current_stream); + has_top_p_arr ? static_cast(maybe_top_p_arr->data_ptr()) : nullptr, batch_size, + top_p_val, vocab_size, stream); TORCH_CHECK(status == cudaSuccess, "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status))); - return renorm_probs; } -torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val) { +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, unsigned int top_k_val, + int64_t cuda_stream) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - top_k_arr = top_k_arr.to(torch::kInt32); - probs = probs.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto renorm_probs = - torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TopKRenormProb( static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), - has_top_k_arr ? 
static_cast(top_k_arr.data_ptr()) : nullptr, batch_size, top_k_val, - vocab_size, torch_current_stream); + has_top_k_arr ? static_cast(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, + top_k_val, vocab_size, stream); TORCH_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status))); - return renorm_probs; } -torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val) { +void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, + std::optional maybe_top_k_arr, unsigned int top_k_val, + int64_t cuda_stream) { CHECK_INPUT(logits); auto device = logits.device(); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) unsigned int batch_size = logits.size(0); unsigned int vocab_size = logits.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - top_k_arr = top_k_arr.to(torch::kInt32); - logits = logits.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto mask_logits = - torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TopKMaskLogits( static_cast(logits.data_ptr()), static_cast(mask_logits.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, batch_size, top_k_val, - vocab_size, torch_current_stream); + has_top_k_arr ? 
static_cast(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, + top_k_val, vocab_size, stream); TORCH_CHECK(status == cudaSuccess, "TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status))); - return mask_logits; } -torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, - torch::Tensor uniform_samples, torch::Tensor target_probs, - torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, - bool deterministic) { +void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids, + at::Tensor uniform_samples, at::Tensor target_probs, + at::Tensor output_token_ids, at::Tensor output_accepted_token_num, + at::Tensor output_emitted_token_num, bool deterministic, + int64_t cuda_stream) { CHECK_INPUT(draft_probs); CHECK_INPUT(draft_token_ids); CHECK_INPUT(uniform_samples); @@ -351,26 +231,15 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso CHECK_EQ(batch_size, output_accepted_token_num.size(0)); CHECK_EQ(batch_size, output_emitted_token_num.size(0)); - draft_probs = draft_probs.to(torch::kFloat32); - draft_token_ids = draft_token_ids.to(torch::kInt32); - uniform_samples = uniform_samples.to(torch::kFloat32); - target_probs = target_probs.to(torch::kFloat32); - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1}, - torch::dtype(torch::kInt32).device(device)); - + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::ChainSpeculativeSampling( static_cast(draft_probs.data_ptr()), static_cast(draft_token_ids.data_ptr()), static_cast(uniform_samples.data_ptr()), static_cast(target_probs.data_ptr()), static_cast(output_token_ids.data_ptr()), static_cast(output_accepted_token_num.data_ptr()), static_cast(output_emitted_token_num.data_ptr()), batch_size, num_speculate_tokens, - vocab_size, deterministic, torch_current_stream); + vocab_size, deterministic, stream); TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " + std::string(cudaGetErrorString(status))); - - return output_token_ids; } diff --git a/python/csrc_aot/single_decode.cu b/python/csrc/single_decode.cu similarity index 81% rename from python/csrc_aot/single_decode.cu rename to python/csrc/single_decode.cu index 74a64e49c..60f9114bb 100644 --- a/python/csrc_aot/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -17,10 +17,10 @@ #include #include +#include #include -#include "flashinfer/pos_enc.cuh" -#include "pytorch_extension_utils.h" +#include "aot_extension_utils.h" namespace flashinfer { @@ -30,12 +30,13 @@ cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT cudaStream_t stream); } // namespace flashinfer -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, - std::optional alibi_slopes, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta) { +using namespace flashinfer; + +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, + std::optional alibi_slopes, at::Tensor o, + unsigned int layout, int window_left, float logits_soft_cap, + float sm_scale, float rope_scale, float rope_theta, + int64_t cuda_stream) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); @@ -62,9 +63,6 
@@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc kv_len = k.size(1); } CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - const at::cuda::CUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q); TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -72,7 +70,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc auto kv_scalar_type = k.scalar_type(); constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; - + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; using DTypeKV = kv_type; @@ -92,13 +90,11 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc cudaError_t status = flashinfer::SingleDecodeWithKVCacheDispatched( - params, static_cast(tmp.data_ptr()), torch_current_stream); + params, static_cast(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); return true; }); }); }); - - return o; } diff --git a/python/csrc_aot/single_prefill.cu b/python/csrc/single_prefill.cu similarity index 84% rename from python/csrc_aot/single_prefill.cu rename to python/csrc/single_prefill.cu index 6d0526a83..417acf232 100644 --- a/python/csrc_aot/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -13,15 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include - #include #include #include +#include #include -#include "pytorch_extension_utils.h" +#include "aot_extension_utils.h" namespace flashinfer { @@ -33,12 +31,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params } // namespace flashinfer -torch::Tensor single_prefill_with_kv_cache( - unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, - std::optional maybe_packed_custom_mask, torch::Tensor tmp, - std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse) { +using namespace flashinfer; + +void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, std::optional maybe_packed_custom_mask, + at::Tensor tmp, std::optional maybe_alibi_slopes, + at::Tensor o, unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { auto device = q.device(); unsigned int head_dim = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; @@ -57,14 +58,10 @@ torch::Tensor single_prefill_with_kv_cache( kv_stride_h = k.stride(0); kv_stride_n = k.stride(1); } - const at::cuda::CUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; @@ -74,6 +71,7 @@ torch::Tensor single_prefill_with_kv_cache( auto 
q_scalar_type = q.scalar_type(); auto kv_scalar_type = k.scalar_type(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; using DTypeKV = kv_type; @@ -103,7 +101,7 @@ torch::Tensor single_prefill_with_kv_cache( flashinfer::SinglePrefillWithKVCacheDispatched( - params, static_cast(tmp.data_ptr()), torch_current_stream); + params, static_cast(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); @@ -112,6 +110,4 @@ torch::Tensor single_prefill_with_kv_cache( }); }); }); - - return o; } diff --git a/python/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu deleted file mode 100644 index ec519c788..000000000 --- a/python/csrc_aot/flashinfer_ops.cu +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -//========== activation ========== - -void silu_and_mul(torch::Tensor& out, torch::Tensor& input); -void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); -void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); - -//========== cascade ========== - -std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, - torch::Tensor s_b); - -void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, - torch::Tensor s_other, std::optional mask = std::nullopt); - -std::vector merge_states(torch::Tensor v, torch::Tensor s); - -//========== decode ========== - -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, - std::optional alibi_slopes, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta); - -std::vector BatchDecodeWithPagedKVCachePlan( - bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, - torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); - -torch::Tensor BatchDecodeWithPagedKVCacheRun( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, - unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse); - -//========== gemm ========== - -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale); - -void 
CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_problems, - torch::Tensor x_ptr, torch::Tensor w_ptr, torch::Tensor y_ptr, - torch::Tensor x_ld, torch::Tensor w_ld, torch::Tensor y_ld, - torch::Tensor empty_x_data, bool weight_column_major); - -//========== norm ========== - -void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); - -void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps); - -void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); - -void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps); - -//========== page ========== - -void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor batch_indices, torch::Tensor positions, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - torch::Tensor kv_indices, torch::Tensor kv_indptr, - torch::Tensor kv_last_page_len, unsigned int layout); - -//========== prefill ========== - -torch::Tensor single_prefill_with_kv_cache( - unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, - std::optional maybe_packed_custom_mask, torch::Tensor tmp, - std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse); - -std::vector BatchPrefillWithKVCachePlan( - unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); - -torch::Tensor BatchPrefillWithRaggedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse); - -torch::Tensor BatchPrefillWithPagedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, - torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, - unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse); - -//========== quantization ========== - -torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); - -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, - torch::Tensor output_indptr, const std::string& bitorder); - -//========== rope ========== - -void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta); - -void 
apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, - unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta); - -void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, - unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor cos_cache, - torch::Tensor sin_cache, torch::Tensor pos_ids, - bool interleave); - -//========== sampling ========== - -torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - bool deterministic); - -std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, - double top_p_val, bool deterministic); - -std::vector top_k_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic); - -std::vector min_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, - double min_p_val, bool deterministic); - -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic); - -torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val); - -torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val); - -torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val); - -torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, - torch::Tensor uniform_samples, torch::Tensor target_probs, - torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, - bool deterministic); - -//========== pybind11 ========== - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // activation - m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); - m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); - - // cascade - m.def("merge_state", &merge_state, "Merge two self-attention states"); - m.def("merge_state_in_place", &merge_state_in_place, - "Merge another self-attention state in-place."); - m.def("merge_states", &merge_states, "Merge multiple self-attention states"); - - // decode - m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, - "Single-request decode with KV-Cache operator"); - m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); - m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); - - // gemm - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); - m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); - - // norm - 
m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); - - // page - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); - - // prefill - m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, - "Single-request prefill attention with KV-Cache operator"); - m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); - m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); - m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); - - // quantization - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); - - // rope - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); - m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, - "Apply Llama 3.1 style RoPE with positional ids"); - - // sampling - m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); - m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, - "Top-k sampling from probabilities"); - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, - "Min-p sampling from probabilities"); - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, - "Top-p sampling from probabilities"); - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, - "Top-k and top-p sampling from probabilities"); - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); - m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); - m.def("chain_speculative_sampling", &chain_speculative_sampling, - "Speculative sampling from sequence of probabilities"); -} diff --git a/python/csrc_aot/pytorch_extension_utils.h b/python/csrc_aot/pytorch_extension_utils.h deleted file mode 100644 index d7545ce5b..000000000 --- a/python/csrc_aot/pytorch_extension_utils.h +++ /dev/null @@ -1,275 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include -#include - -#include -#include - -#include "generated/dispatch.inc" - -using namespace flashinfer; - -#ifdef FLASHINFER_ENABLE_BF16 -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) 
\ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Half: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::BFloat16: { \ - using c_type = nv_bfloat16; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Half: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#endif - -#ifdef FLASHINFER_ENABLE_FP8 -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Float8_e4m3fn: { \ - using c_type = __nv_fp8_e4m3; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Float8_e5m2: { \ - using c_type = __nv_fp8_e5m2; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - }() -#endif - -#if defined(FLASHINFER_ENABLE_BF16) && defined(FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Half: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::BFloat16: { \ - using c_type = nv_bfloat16; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Float8_e4m3fn: { \ - using c_type = __nv_fp8_e4m3; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Float8_e5m2: { \ - using c_type = __nv_fp8_e5m2; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#elif defined(FLASHINFER_ENABLE_BF16) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Half: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::BFloat16: { \ - using c_type = nv_bfloat16; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#elif defined(FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) 
\ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Half: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Float8_e4m3fn: { \ - using c_type = __nv_fp8_e4m3; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Float8_e5m2: { \ - using c_type = __nv_fp8_e5m2; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Half: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#endif - -#define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \ - [&]() -> bool { \ - if (kv_dtype == q_dtype) { \ - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_dtype, c_type_q, [&] { \ - using c_type_kv = c_type_q; \ - return __VA_ARGS__(); \ - }); \ - } else { \ - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_dtype, c_type_q, [&] { \ - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_dtype, c_type_kv, \ - [&] { return __VA_ARGS__(); }); \ - }); \ - } \ - }() - -#define _DISPATCH_SWITCH(var_name, cond, ...) \ - [&]() -> bool { \ - switch (cond) { \ - __VA_ARGS__ \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() - -#define _DISPATCH_CASE(case_expr, case_var, ...) \ - case case_expr: { \ - constexpr auto case_var = case_expr; \ - return __VA_ARGS__(); \ - } - -#define DISPATCH_head_dim(expr, const_expr, ...) \ - _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) - -#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("positional encoding mode", expr, \ - _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) - -#define DISPATCH_allow_fp16_qk_reduction(expr, const_expr, ...) \ - _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ - _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) - -#define DISPATCH_mask_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) - -#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ - [&]() -> bool { \ - if (use_logits_soft_cap) { \ - constexpr bool USE_LOGITS_SOFT_CAP = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool USE_LOGITS_SOFT_CAP = false; \ - return __VA_ARGS__(); \ - } \ - }() - -inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, const char* a_name, - const char* b_name) { - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", - b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); - } -} - -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { - return (uint32_t(a) << 16) | uint32_t(b); -} - -#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ - TORCH_CHECK(num_qo_heads % num_kv_heads == 0, "num_qo_heads(", num_qo_heads, \ - ") must be divisible by num_kv_heads(", num_kv_heads, ")") - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") - -#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) - -#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) - -#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) - -inline bool is_float8_tensor(const torch::Tensor& tensor) { - return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || - tensor.scalar_type() == at::ScalarType::Float8_e5m2; -} diff --git a/python/flashinfer/activation.py b/python/flashinfer/activation.py index 356ff1ab7..5c78918fa 100644 --- a/python/flashinfer/activation.py +++ b/python/flashinfer/activation.py @@ -18,13 +18,8 @@ import torch -from .jit import ( - FLASHINFER_GEN_SRC_DIR, - gen_act_and_mul_cu, - has_prebuilt_ops, - load_cuda_ops, -) -from .utils import register_custom_op, register_fake_op +from .jit import gen_act_and_mul_module, has_prebuilt_ops, load_cuda_ops +from .utils import get_cuda_stream, register_custom_op, register_fake_op silu_def_cu_str = r""" __device__ __forceinline__ float silu(const float& val) { @@ -54,17 +49,6 @@ } -def compile_act_and_mul_module(name: str, act_func_def: str, verbose: bool = False): - gen_act_and_mul_cu(name, act_func_def) - return load_cuda_ops( - f"{name}_and_mul", - [ - FLASHINFER_GEN_SRC_DIR / f"{name}_and_mul.cu", - ], - verbose=verbose, - ) - - _jit_modules = {} @@ -76,7 +60,7 @@ def get_act_and_mul_module(act_func_name: str): module = _kernels else: - module = compile_act_and_mul_module( + module = gen_act_and_mul_module( act_func_name, act_func_def_str[act_func_name] ) @@ -86,7 +70,8 @@ def get_act_and_mul_module(act_func_name: str): @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",)) def _act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None: - fn(out, input) + with input.device as device: # device guard + fn(out, input, get_cuda_stream(device)) @register_fake_op(f"flashinfer::{fname}") def _fake_act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None: @@ -136,7 +121,10 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - get_act_and_mul_module("silu").silu_and_mul(out, input) + get_act_and_mul_module("silu").silu_and_mul( + out, + input, + ) return out diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index c72cf7b84..38458bb91 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -21,7 +21,7 @@ from .decode import BatchDecodeWithPagedKVCacheWrapper from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache -from .utils import register_custom_op, register_fake_op +from .utils 
import get_cuda_stream, register_custom_op, register_fake_op _cascade_module = None @@ -93,7 +93,15 @@ def merge_state( >>> s_merged.shape torch.Size([2048, 32]) """ - return get_cascade_module().merge_state(v_a, s_a, v_b, s_b) + with v_a.device as device: # device guard + s_a = s_a.to(torch.float32) + s_b = s_b.to(torch.float32) + v_merged = torch.empty_like(v_a) + s_merged = torch.empty_like(s_a) + get_cascade_module().merge_state( + v_a, s_a, v_b, s_b, v_merged, s_merged, get_cuda_stream(device) + ) + return v_merged, s_merged @register_fake_op("flashinfer::merge_state") @@ -149,7 +157,12 @@ def merge_state_in_place( >>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") >>> flashinfer.merge_state_in_place(v, s, v_other, s_other) """ - get_cascade_module().merge_state_in_place(v, s, v_other, s_other, mask) + with v.device as device: # device guard + s = s.to(torch.float32) + s_other = s_other.to(torch.float32) + get_cascade_module().merge_state_in_place( + v, s, v_other, s_other, mask, get_cuda_stream(device) + ) @register_fake_op("flashinfer::merge_state_in_place") @@ -201,16 +214,27 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch. >>> s_merged.shape torch.Size([2048, 32]) """ - return get_cascade_module().merge_states(v, s) + with v.device as device: # device guard + s = s.to(torch.float32) + seq_len, _, num_heads, head_dim = v.size() + v_merged = torch.empty( + seq_len, num_heads, head_dim, dtype=v.dtype, device=device + ) + s_merged = torch.empty(seq_len, num_heads, dtype=torch.float32, device=device) + get_cascade_module().merge_states( + v, s, v_merged, s_merged, get_cuda_stream(device) + ) + return v_merged, s_merged @register_fake_op("flashinfer::merge_states") def _fake_merge_states( v: torch.Tensor, s: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - v = torch.empty_like(v) - s = torch.empty_like(s) - return v, s + seq_len, _, num_heads, head_dim = v.size() + v_merged = torch.empty(seq_len, num_heads, head_dim, dtype=v.dtype) + s_merged = torch.empty(seq_len, num_heads, dtype=torch.float32) + return v_merged, s_merged class MultiLevelCascadeAttentionWrapper: diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index e12c9d062..705d8cf4b 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -22,14 +22,13 @@ import torch from .jit import ( - gen_batch_decode_cu, - gen_batch_decode_mla_cu, - gen_single_decode_cu, + gen_batch_decode_mla_module, + gen_batch_decode_module, + gen_single_decode_module, get_batch_decode_mla_uri, get_batch_decode_uri, get_single_decode_uri, has_prebuilt_ops, - load_cuda_ops, prebuilt_ops_uri, ) from .prefill import get_batch_prefill_module, get_single_prefill_module @@ -45,47 +44,11 @@ _get_range_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, + get_cuda_stream, register_custom_op, register_fake_op, ) - -def compile_single_decode_module( - *args, - verbose: bool = False, -): - uri, path = gen_single_decode_cu(*args) - return load_cuda_ops( - uri, - [path], - verbose=verbose, - ) - - -def compile_batch_decode_module( - *args, - verbose: bool = False, -): - uri, path = gen_batch_decode_cu(*args) - return load_cuda_ops( - uri, - [path], - verbose=verbose, - ) - - -def compile_batch_decode_mla_module( - *args, - verbose: bool = False, -): - uri, path = gen_batch_decode_mla_cu(*args) - return load_cuda_ops( - uri, - [path], - verbose=verbose, - ) - - _single_decode_modules = {} _batch_decode_modules = {} _batch_decode_mla_modules = {} @@ -100,7 
+63,7 @@ def get_single_decode_module(*args): run_func = _kernels.single_decode_with_kv_cache else: - run_func = compile_single_decode_module(*args).run + run_func = gen_single_decode_module(*args).run # torch library for single_decode_with_kv_cache @@ -118,19 +81,25 @@ def run_single_decode( rope_scale: float, rope_theta: float, ) -> torch.Tensor: - return run_func( - q, - k, - v, - tmp, - alibi_slopes, - kv_layout_code, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - ) + with q.device as device: + o = torch.empty_like(q) + run_func( + q, + k, + v, + tmp, + alibi_slopes, + o, + kv_layout_code, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + get_cuda_stream(device), + ) + + return o @register_fake_op(f"flashinfer::{uri}_run") def _fake_run_single_decode( @@ -176,7 +145,7 @@ def get_batch_decode_module(*args): ) run_func = _kernels.batch_decode_with_paged_kv_cache_run else: - mod = compile_batch_decode_module(*args) + mod = gen_batch_decode_module(*args) plan_func = mod.plan run_func = mod.run @@ -211,25 +180,30 @@ def run_batch_decode( rope_theta: float, maybe_lse: Optional[torch.Tensor], ) -> torch.Tensor: - return run_func( - float_workspace_buffer, - int_workspace_buffer, - plan_info_vec, - q, - paged_k_cache, - paged_v_cache, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - alibi_slopes, - kv_layout_code, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - maybe_lse, - ) + with q.device as device: + o = torch.empty_like(q) + run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + alibi_slopes, + o, + kv_layout_code, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o @register_fake_op(f"flashinfer::{uri}_run") def _fake_run_batch_decode( @@ -273,16 +247,29 @@ def single_decode_with_kv_cache_with_jit_module( kv_layout: str = "NHD", window_left: int = -1, ): - tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) - return jit_module.run( - q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args - ) + with q.device as device: + tmp = _get_cache_buf( + "single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, device + ) + o = torch.empty_like(q) + jit_module.run( + q, + k, + v, + tmp, + o, + TensorLayout[kv_layout].value, + window_left, + *args, + get_cuda_stream(device), + ) + return o def get_batch_decode_mla_module(*args): global _batch_decode_mla_modules if args not in _batch_decode_mla_modules: - _batch_decode_mla_modules[args] = compile_batch_decode_mla_module(*args) + _batch_decode_mla_modules[args] = gen_batch_decode_mla_module(*args) return _batch_decode_mla_modules[args] @@ -767,18 +754,20 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap False, # allow_fp16_qk_reduction ) - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - ) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_host, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + 
get_cuda_stream(device), + ) else: self._cached_module = get_batch_decode_module( q_data_type, @@ -790,17 +779,19 @@ def plan( window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap ) - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - ) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + indptr_host, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left @@ -1288,16 +1279,18 @@ def plan( window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap ) - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - indptr, - batch_size, - num_qo_heads, - page_size, - self.is_cuda_graph_enabled, - ) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + indptr, + batch_size, + num_qo_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._sm_scale = sm_scale self._window_left = window_left diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index f3be09284..046fe36f1 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -23,7 +23,9 @@ from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops from .utils import ( + _get_cache_buf, get_compute_capability, + get_cuda_stream, get_indptr, register_custom_op, register_fake_op, @@ -52,18 +54,33 @@ def get_gemm_module(): # torch library for bmm_fp8 - @register_custom_op("flashinfer::bmm_fp8", mutates_args=("D",)) + @register_custom_op( + "flashinfer::bmm_fp8", mutates_args=("workspace_buffer", "D") + ) def bmm_fp8( + workspace_buffer: torch.Tensor, A: torch.Tensor, B: torch.Tensor, D: torch.Tensor, A_scale: torch.Tensor, B_scale: torch.Tensor, ) -> None: - module.bmm_fp8(A, B, D, A_scale, B_scale) + with A.device as device: + cublas_handle = torch.cuda.current_blas_handle() + module.bmm_fp8( + A, + B, + D, + A_scale, + B_scale, + workspace_buffer, + cublas_handle, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::bmm_fp8") def _fake_bmm_fp8( + workspace_buffer: torch.Tensor, A: torch.Tensor, B: torch.Tensor, D: torch.Tensor, @@ -88,18 +105,20 @@ def cutlass_segment_gemm( empty_x_data: torch.Tensor, weight_column_major: bool, ) -> None: - module.cutlass_segment_gemm( - workspace_buffer, - all_problems, - x_data, - w_data, - y_data, - x_ld, - w_ld, - y_ld, - empty_x_data, - weight_column_major, - ) + with x_data.device as device: + module.cutlass_segment_gemm( + workspace_buffer, + all_problems, + x_data, + w_data, + y_data, + x_ld, + w_ld, + y_ld, + empty_x_data, + weight_column_major, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::cutlass_segment_gemm") def _fake_cutlass_segment_gemm( @@ -145,7 +164,10 @@ def get_gemm_sm90_module(): # torch library for cutlass_segment_gemm_sm90 - @register_custom_op("flashinfer::cutlass_segment_gemm_sm90", mutates_args=("y")) + @register_custom_op( + "flashinfer::cutlass_segment_gemm_sm90", + 
mutates_args=("workspace_buffer", "y"), + ) def cutlass_segment_gemm_sm90( workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, @@ -160,19 +182,21 @@ def cutlass_segment_gemm_sm90( empty_x_data: torch.Tensor, weight_column_major: bool, ) -> None: - module.cutlass_segment_gemm_sm90( - workspace_buffer, - int_workspace_buffer, - all_problems, - x_data, - w_data, - y_data, - x_stride, - w_stride, - y_stride, - empty_x_data, - weight_column_major, - ) + with x_data.device as device: + module.cutlass_segment_gemm_sm90( + workspace_buffer, + int_workspace_buffer, + all_problems, + x_data, + w_data, + y_data, + x_stride, + w_stride, + y_stride, + empty_x_data, + weight_column_major, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::cutlass_segment_gemm_sm90") def _fake_cutlass_segment_gemm_sm90( @@ -699,5 +723,6 @@ def bmm_fp8( device=A.device, dtype=dtype, ) - get_gemm_module().bmm_fp8(A, B, out, A_scale, B_scale) + workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) + get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale) return out diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index b30f9a244..ee95117c3 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -14,34 +14,27 @@ limitations under the License. """ -import logging -import os -import re -from pathlib import Path -from typing import List, Union - -import torch.utils.cpp_extension as torch_cpp_ext -from filelock import FileLock - # Re-export -from .activation import gen_act_and_mul_cu as gen_act_and_mul_cu +from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str -from .attention import gen_batch_decode_cu as gen_batch_decode_cu -from .attention import gen_batch_decode_mla_cu as gen_batch_decode_mla_cu -from .attention import gen_batch_prefill_cu as gen_batch_prefill_cu -from .attention import gen_single_decode_cu as gen_single_decode_cu -from .attention import gen_single_prefill_cu as gen_single_prefill_cu +from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module +from .attention import gen_batch_decode_module as gen_batch_decode_module +from .attention import gen_batch_prefill_module as gen_batch_prefill_module +from .attention import ( + gen_customize_single_decode_module as gen_customize_single_decode_module, +) +from .attention import ( + gen_customize_single_prefill_module as gen_customize_single_prefill_module, +) +from .attention import gen_single_decode_module as gen_single_decode_module +from .attention import gen_single_prefill_module as gen_single_prefill_module from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri from .attention import get_batch_decode_uri as get_batch_decode_uri from .attention import get_batch_prefill_uri as get_batch_prefill_uri from .attention import get_single_decode_uri as get_single_decode_uri from .attention import get_single_prefill_uri as get_single_prefill_uri -from .env import CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS -from .env import FLASHINFER_CSRC_DIR as FLASHINFER_CSRC_DIR -from .env import FLASHINFER_GEN_SRC_DIR as FLASHINFER_GEN_SRC_DIR -from .env import FLASHINFER_INCLUDE_DIR as FLASHINFER_INCLUDE_DIR -from .env import FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR -from .env import FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR +from .core import clear_cache_dir, load_cuda_ops +from .env import * try: 
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri # type: ignore[import] @@ -50,109 +43,3 @@ except ImportError: prebuilt_ops_uri = set() has_prebuilt_ops = False - -if not os.path.exists(FLASHINFER_WORKSPACE_DIR): - os.makedirs(FLASHINFER_WORKSPACE_DIR) - - -class FlashInferJITLogger(logging.Logger): - def __init__(self, name): - super().__init__(name) - self.setLevel(logging.INFO) - self.addHandler(logging.StreamHandler()) - log_path = FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log" - if not os.path.exists(log_path): - # create an empty file - with open(log_path, "w") as f: - pass - self.addHandler(logging.FileHandler(log_path)) - # set the format of the log - self.handlers[0].setFormatter( - logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - ) - self.handlers[1].setFormatter( - logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - ) - - def info(self, msg): - super().info("flashinfer.jit: " + msg) - - -logger = FlashInferJITLogger("flashinfer.jit") - - -def check_cuda_arch(): - # cuda arch check for fp8 at the moment. - for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): - arch = int(re.search(r"compute_(\d+)", cuda_arch_flags).group(1)) - if arch < 75: - raise RuntimeError("FlashInfer requires sm75+") - - -def clear_cache_dir(): - if os.path.exists(FLASHINFER_JIT_DIR): - for file in os.listdir(FLASHINFER_JIT_DIR): - os.remove(os.path.join(FLASHINFER_JIT_DIR, file)) - - -def remove_unwanted_pytorch_nvcc_flags(): - REMOVE_NVCC_FLAGS = [ - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ] - for flag in REMOVE_NVCC_FLAGS: - try: - torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) - except ValueError: - pass - - -remove_unwanted_pytorch_nvcc_flags() - - -def load_cuda_ops( - name: str, - sources: List[Union[str, Path]], - extra_cflags: List[str] = [], - extra_cuda_cflags: List[str] = [], - extra_ldflags=None, - extra_include_paths=None, - verbose=False, -): - cflags = ["-O3", "-Wno-switch-bool"] - cuda_cflags = [ - "-O3", - "-std=c++17", - "--threads", - "4", - "-use_fast_math", - "-DFLASHINFER_ENABLE_BF16", - "-DFLASHINFER_ENABLE_FP8", - ] - cflags += extra_cflags - cuda_cflags += extra_cuda_cflags - logger.info(f"Loading JIT ops: {name}") - check_cuda_arch() - build_directory = FLASHINFER_JIT_DIR / name - if not os.path.exists(build_directory): - os.makedirs(build_directory, exist_ok=True) - if extra_include_paths is None: - extra_include_paths = [ - FLASHINFER_INCLUDE_DIR, - FLASHINFER_CSRC_DIR, - ] + CUTLASS_INCLUDE_DIRS - lock = FileLock(FLASHINFER_JIT_DIR / f"{name}.lock", thread_local=False) - with lock: - return torch_cpp_ext.load( - name, - list(map(lambda _: str(_), sources)), - extra_cflags=cflags, - extra_cuda_cflags=cuda_cflags, - extra_ldflags=extra_ldflags, - extra_include_paths=list(map(lambda _: str(_), extra_include_paths)), - build_directory=build_directory, - verbose=verbose, - with_cuda=True, - ) diff --git a/python/flashinfer/jit/activation.py b/python/flashinfer/jit/activation.py index e4d756cdb..1e6a655f5 100644 --- a/python/flashinfer/jit/activation.py +++ b/python/flashinfer/jit/activation.py @@ -18,14 +18,12 @@ import jinja2 +from .core import load_cuda_ops from .env import FLASHINFER_GEN_SRC_DIR from .utils import write_if_different activation_templ = r""" -#include -#include #include -#include #include "pytorch_extension_utils.h" {% set func_name = act_func_name ~ '_and_mul' %} @@ -34,13 +32,12 @@ {{ act_func_def }} -void {{ 
func_name }}(torch::Tensor& out, torch::Tensor& input) { +void {{ func_name }}(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); @@ -63,11 +60,16 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str: return template.render(act_func_name=act_func_name, act_func_def=act_func_def) -def gen_act_and_mul_cu(act_func_name: str, act_func_def: str) -> None: +def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> None: gen_directory = FLASHINFER_GEN_SRC_DIR if not os.path.exists(gen_directory): os.makedirs(gen_directory) + sources = [gen_directory / f"{act_func_name}_and_mul.cu"] write_if_different( - gen_directory / f"{act_func_name}_and_mul.cu", + sources[0], get_act_and_mul_cu_str(act_func_name, act_func_def), ) + return load_cuda_ops( + f"{act_func_name}_and_mul", + sources, + ) diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py index c6e4b7881..3d2402b34 100644 --- a/python/flashinfer/jit/attention.py +++ b/python/flashinfer/jit/attention.py @@ -21,13 +21,19 @@ import jinja2 import torch -from .batch_decode_mla_templ import batch_decode_mla_templ -from .batch_decode_templ import batch_decode_templ -from .batch_prefill_templ import batch_prefill_templ +from .batch_decode_mla_templ import batch_decode_mla_suffix, batch_decode_mla_templ +from .batch_decode_templ import batch_decode_suffix, batch_decode_templ +from .batch_prefill_templ import batch_prefill_suffix, batch_prefill_templ +from .core import load_cuda_ops from .env import FLASHINFER_GEN_SRC_DIR -from .single_decode_templ import customizable_single_decode_templ, single_decode_templ +from .single_decode_templ import ( + customizable_single_decode_templ, + single_decode_suffix, + single_decode_templ, +) from .single_prefill_templ import ( customizable_single_prefill_templ, + single_prefill_suffix, single_prefill_templ, ) from .utils import ( @@ -38,7 +44,13 @@ ) -def get_single_decode_cu_str( +def render_templates(template_strs: List[str], context: dict) -> List[str]: + return [ + template.render(**context) for template in map(jinja2.Template, template_strs) + ] + + +def get_single_decode_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -46,16 +58,18 @@ def get_single_decode_cu_str( pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, -) -> str: - template = jinja2.Template(single_decode_templ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], - use_sliding_window="true" if use_sliding_window else "false", - use_logits_soft_cap="true" if use_logits_soft_cap else "false", +) -> List[str]: + return render_templates( + single_decode_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + 
"use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + }, ) @@ -79,21 +93,20 @@ def get_single_decode_uri( ) -def gen_single_decode_cu(*args) -> Tuple[str, pathlib.Path]: +def gen_single_decode_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR - if not os.path.exists(gen_directory): - os.makedirs(gen_directory) + os.makedirs(gen_directory, exist_ok=True) uri = get_single_decode_uri(*args) - file_name = f"{uri}.cu" - path = gen_directory / file_name - write_if_different( - path, - get_single_decode_cu_str(*args), - ) - return uri, path + sources = get_single_decode_sources(*args) + source_paths = [] + for suffix, source in zip(single_decode_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + return load_cuda_ops(uri, source_paths) -def get_batch_decode_cu_str( +def get_batch_decode_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -102,17 +115,19 @@ def get_batch_decode_cu_str( pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, -) -> str: - template = jinja2.Template(batch_decode_templ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - dtype_idx=dtype_map[dtype_idx], - head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], - use_sliding_window="true" if use_sliding_window else "false", - use_logits_soft_cap="true" if use_logits_soft_cap else "false", +) -> List[str]: + return render_templates( + batch_decode_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "dtype_idx": dtype_map[dtype_idx], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + }, ) @@ -138,21 +153,21 @@ def get_batch_decode_uri( ) -def gen_batch_decode_cu(*args) -> Tuple[str, pathlib.Path]: +def gen_batch_decode_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR if not os.path.exists(gen_directory): os.makedirs(gen_directory) uri = get_batch_decode_uri(*args) - file_name = f"{uri}.cu" - path = gen_directory / file_name - write_if_different( - path, - get_batch_decode_cu_str(*args), - ) - return uri, path + sources = get_batch_decode_sources(*args) + source_paths = [] + for suffix, source in zip(batch_decode_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + return load_cuda_ops(uri, source_paths) -def get_batch_decode_mla_cu_str( +def get_batch_decode_mla_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -160,18 +175,18 @@ def get_batch_decode_mla_cu_str( head_dim: int, use_sliding_window: bool, use_logits_soft_cap: bool, -) -> str: - template = jinja2.Template(batch_decode_mla_templ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - dtype_idx=dtype_map[dtype_idx], - head_dim_ckv=head_dim, - head_dim_kpe=head_dim - // 8, # fixme: head_dim_ckv(kv_lora_rank) is 8 times the size of head_dim_kpe(qk_rope_head_dim) for all MLA model (DeepSeek-V2-Lite, DeepSeek-V2.5, MiniCPM3) at the time Oct.2024 - use_sliding_window="true" if use_sliding_window else "false", - use_logits_soft_cap="true" if use_logits_soft_cap else "false", +) -> List[str]: + return render_templates( + 
batch_decode_mla_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "dtype_idx": dtype_map[dtype_idx], + "head_dim": head_dim, + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + }, ) @@ -195,21 +210,21 @@ def get_batch_decode_mla_uri( ) -def gen_batch_decode_mla_cu(*args) -> None: +def gen_batch_decode_mla_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR if not os.path.exists(gen_directory): os.makedirs(gen_directory) uri = get_batch_decode_mla_uri(*args) - file_name = f"{uri}.cu" - path = gen_directory / file_name - write_if_different( - path, - get_batch_decode_mla_cu_str(*args), - ) - return uri, path + sources = get_batch_decode_mla_sources(*args) + source_paths = [] + for suffix, source in zip(batch_decode_mla_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + return load_cuda_ops(uri, source_paths) -def get_single_prefill_cu_str( +def get_single_prefill_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -218,17 +233,19 @@ def get_single_prefill_cu_str( use_sliding_window: bool, use_logits_soft_cap: bool, use_fp16_qk_reduction: bool, -) -> str: - template = jinja2.Template(single_prefill_templ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], - use_sliding_window="true" if use_sliding_window else "false", - use_logits_soft_cap="true" if use_logits_soft_cap else "false", - use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false", +) -> List[str]: + return render_templates( + single_prefill_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + "use_fp16_qk_reduction": "true" if use_fp16_qk_reduction else "false", + }, ) @@ -254,21 +271,22 @@ def get_single_prefill_uri( ) -def gen_single_prefill_cu(*args) -> Tuple[str, pathlib.Path]: +def gen_single_prefill_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR if not os.path.exists(gen_directory): os.makedirs(gen_directory) uri = get_single_prefill_uri(*args) - file_name = f"{uri}.cu" - path = gen_directory / file_name - write_if_different( - path, - get_single_prefill_cu_str(*args), - ) - return uri, path + sources = get_single_prefill_sources(*args) + source_paths = [] + for suffix, source in zip(single_prefill_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(uri, source_paths) -def get_batch_prefill_cu_str( +def get_batch_prefill_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -278,18 +296,20 @@ def get_batch_prefill_cu_str( use_sliding_window: bool, use_logits_soft_cap: bool, use_fp16_qk_reduction: bool, -) -> str: - template = jinja2.Template(batch_prefill_templ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - dtype_idx=dtype_map[dtype_idx], - head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], - 
use_sliding_window="true" if use_sliding_window else "false", - use_logits_soft_cap="true" if use_logits_soft_cap else "false", - use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false", +) -> List[str]: + return render_templates( + batch_prefill_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "dtype_idx": dtype_map[dtype_idx], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + "use_fp16_qk_reduction": "true" if use_fp16_qk_reduction else "false", + }, ) @@ -317,21 +337,22 @@ def get_batch_prefill_uri( ) -def gen_batch_prefill_cu(*args) -> Tuple[str, pathlib.Path]: +def gen_batch_prefill_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR if not os.path.exists(gen_directory): os.makedirs(gen_directory) uri = get_batch_prefill_uri(*args) - file_name = f"{uri}.cu" - path = gen_directory / file_name - write_if_different( - path, - get_batch_prefill_cu_str(*args), - ) - return uri, path + sources = get_batch_prefill_sources(*args) + source_paths = [] + for suffix, source in zip(batch_prefill_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(uri, source_paths) -def get_customize_single_decode_cu_str( +def get_customize_single_decode_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -342,8 +363,7 @@ def get_customize_single_decode_cu_str( additional_input_scalar_var_types: List[str], variant_name: str, variant_decl: str, -) -> str: - template = jinja2.Template(customizable_single_decode_templ) +) -> List[str]: additional_params_decl = "".join( [ f"{dtype}* {var};\n" @@ -377,7 +397,7 @@ def get_customize_single_decode_cu_str( + [f", {var}({var})" for var in additional_input_scalar_var_names] ) additional_func_params = "".join( - [f", torch::Tensor {var}" for var in additional_input_tensor_var_names] + [f", at::Tensor {var}" for var in additional_input_tensor_var_names] + [ f", {dtype} {var}" for dtype, var in zip( @@ -395,22 +415,25 @@ def get_customize_single_decode_cu_str( + [f", {var}" for var in additional_input_scalar_var_names] ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - head_dim=head_dim, - additional_params_decl=additional_params_decl, - additional_params=additional_params, - additional_params_init=additional_params_init, - variant_decl=variant_decl, - variant_name=variant_name, - additional_func_params=additional_func_params, - additional_params_data=additional_params_data, + return render_templates( + customizable_single_decode_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "additional_params_decl": additional_params_decl, + "additional_params": additional_params, + "additional_params_init": additional_params_init, + "variant_decl": variant_decl, + "variant_name": variant_name, + "additional_func_params": additional_func_params, + "additional_params_data": additional_params_data, + }, ) -def get_customize_single_prefill_cu_str( +def get_customize_single_prefill_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, @@ -421,8 +444,7 @@ def get_customize_single_prefill_cu_str( additional_input_scalar_var_types: List[str], 
variant_name: str, variant_decl: str, -) -> str: - template = jinja2.Template(customizable_single_prefill_templ) +) -> List[str]: additional_params_decl = "".join( [ f"{dtype}* {var};\n" @@ -456,7 +478,7 @@ def get_customize_single_prefill_cu_str( + [f", {var}({var})" for var in additional_input_scalar_var_names] ) additional_func_params = "".join( - [f", torch::Tensor {var}" for var in additional_input_tensor_var_names] + [f", at::Tensor {var}" for var in additional_input_tensor_var_names] + [ f", {dtype} {var}" for dtype, var in zip( @@ -474,16 +496,47 @@ def get_customize_single_prefill_cu_str( + [f", {var}" for var in additional_input_scalar_var_names] ) - return template.render( - dtype_q=dtype_map[dtype_q], - dtype_kv=dtype_map[dtype_kv], - dtype_o=dtype_map[dtype_o], - head_dim=head_dim, - additional_params_decl=additional_params_decl, - additional_params=additional_params, - additional_params_init=additional_params_init, - variant_decl=variant_decl, - variant_name=variant_name, - additional_func_params=additional_func_params, - additional_params_data=additional_params_data, + return render_templates( + customizable_single_prefill_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "additional_params_decl": additional_params_decl, + "additional_params": additional_params, + "additional_params_init": additional_params_init, + "variant_decl": variant_decl, + "variant_name": variant_name, + "additional_func_params": additional_func_params, + "additional_params_data": additional_params_data, + }, ) + + +def gen_customize_single_decode_module(module_name, *args): + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + sources = get_customize_single_decode_sources(*args) + source_paths = [] + for suffix, source in zip(single_decode_suffix, sources): + path = gen_directory / f"{module_name}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(module_name, source_paths) + + +def gen_customize_single_prefill_module(module_name, *args): + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + sources = get_customize_single_prefill_sources(*args) + source_paths = [] + for suffix, source in zip(single_prefill_suffix, sources): + path = gen_directory / f"{module_name}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(module_name, source_paths) diff --git a/python/flashinfer/jit/batch_decode_mla_templ.py b/python/flashinfer/jit/batch_decode_mla_templ.py index 1dd83e9bb..bf553e04c 100644 --- a/python/flashinfer/jit/batch_decode_mla_templ.py +++ b/python/flashinfer/jit/batch_decode_mla_templ.py @@ -14,9 +14,14 @@ limitations under the License. 
""" -batch_decode_mla_templ = r""" -#include -#include +batch_decode_mla_suffix = [ + "_plan.cu", + "_run.cu", + "_pybind.cc", +] + +batch_decode_mla_templ = [ + r"""#include #include #include #include @@ -29,22 +34,20 @@ using AttentionVariant = ComposedAttention; std::vector BatchDecodeWithPagedKVCachePlanMLA( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor indptr, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int page_size, - bool enable_cuda_graph) { + bool enable_cuda_graph, + int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - indptr = indptr.to(torch::kCPU); DecodePlanInfo plan_info; + cudaStream_t stream = reinterpret_cast(cuda_stream); auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA<{{ head_dim_ckv }}, {{ head_dim_kpe }}, AttentionVariant>; @@ -56,7 +59,7 @@ int_workspace_size_in_bytes, plan_info, static_cast<{{ dtype_idx }}*>(indptr.data_ptr()), - batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/torch_current_stream, + batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream, work_estimation_func); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ", @@ -64,20 +67,35 @@ return plan_info.ToVector(); } +""" + r""" +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +using ParamsT = BatchDecodeParamsMLA<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +using AttentionVariant = ComposedAttention; -std::vector BatchDecodeWithPagedKVCacheRunMLA( - torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, +void BatchDecodeWithPagedKVCacheRunMLA( + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, std::vector plan_info_vec, - torch::Tensor q_nope, - torch::Tensor q_pe, - torch::Tensor paged_ckv_cache, - torch::Tensor paged_kpe_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, + at::Tensor q_nope, + at::Tensor q_pe, + at::Tensor paged_ckv_cache, + at::Tensor paged_kpe_cache, + at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + at::Tensor o, float sm_scale, int window_left, - float logits_soft_cap, float rope_scale, float rope_theta, bool return_lse) { + float logits_soft_cap, float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); @@ -86,12 +104,10 @@ int64_t num_qo_heads = q_nope.size(1); int64_t page_size = paged_ckv_cache.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q_nope); - torch::Tensor lse; - if (return_lse) { - lse = torch::empty({batch_size, num_qo_heads}, 
q_nope.options().dtype((torch::kFloat32))); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1)); } TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -112,7 +128,7 @@ ParamsT params( static_cast<{{ dtype_q }}*>(q_nope.data_ptr()), static_cast<{{ dtype_q }}*>(q_pe.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast<{{ dtype_o }}*>(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + /*lse=*/(maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr), num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); {{ dtype_o }}* tmp_v = nullptr; @@ -135,16 +151,38 @@ params, tmp_v, tmp_s, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } } +""", + r"""#include "pytorch_extension_utils.h" + +std::vector BatchDecodeWithPagedKVCachePlanMLA( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor indptr, + unsigned int batch_size, unsigned int num_qo_heads, + unsigned int page_size, + bool enable_cuda_graph, + int64_t cuda_stream); + +void BatchDecodeWithPagedKVCacheRunMLA( + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + std::vector plan_info_vec, + at::Tensor q_nope, + at::Tensor q_pe, + at::Tensor paged_ckv_cache, + at::Tensor paged_kpe_cache, + at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + at::Tensor o, + float sm_scale, + int window_left, + float logits_soft_cap, float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("plan", &BatchDecodeWithPagedKVCachePlanMLA); m.def("run", &BatchDecodeWithPagedKVCacheRunMLA); } -""" +""", +] diff --git a/python/flashinfer/jit/batch_decode_templ.py b/python/flashinfer/jit/batch_decode_templ.py index 5898b75c8..59700e6f3 100644 --- a/python/flashinfer/jit/batch_decode_templ.py +++ b/python/flashinfer/jit/batch_decode_templ.py @@ -14,11 +14,15 @@ limitations under the License. 
""" -batch_decode_templ = r""" -#include -#include +batch_decode_suffix = [ + "_plan.cu", + "_run.cu", + "_pybind.cc", +] + +batch_decode_templ = [ + r"""#include #include -#include #include #include #include "pytorch_extension_utils.h" @@ -30,21 +34,18 @@ using AttentionVariant = ComposedAttention; std::vector BatchDecodeWithPagedKVCachePlan( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor indptr, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph) { + bool enable_cuda_graph, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU"); + cudaStream_t stream = reinterpret_cast(cuda_stream); DecodePlanInfo plan_info; DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { auto work_estimation_func = @@ -58,7 +59,7 @@ int_workspace_size_in_bytes, plan_info, static_cast<{{ dtype_idx }}*>(indptr.data_ptr()), - batch_size, num_qo_heads, /*num_kv_heads,*/ page_size, enable_cuda_graph, /*stream=*/torch_current_stream, + batch_size, num_qo_heads, /*num_kv_heads,*/ page_size, enable_cuda_graph, stream, work_estimation_func); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", @@ -66,20 +67,36 @@ }); return plan_info.ToVector(); } +""", + r""" +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; -torch::Tensor BatchDecodeWithPagedKVCacheRun( - torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using ParamsT = BatchDecodeParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +using AttentionVariant = ComposedAttention; + +void BatchDecodeWithPagedKVCacheRun( + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, std::vector plan_info_vec, - torch::Tensor q, - torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, - std::optional alibi_slopes, + at::Tensor q, + at::Tensor paged_k_cache, + at::Tensor paged_v_cache, + at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + std::optional alibi_slopes, + at::Tensor o, unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse) { + std::optional maybe_lse, + int64_t cuda_stream) { DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(kv_layout_code); @@ -94,15 +111,11 @@ page_size = paged_k_cache.size(1); num_kv_heads = paged_k_cache.size(2); } - - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = 
c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q); + if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -150,17 +163,43 @@ } params.padded_batch_size = plan_info.padded_batch_size; + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = BatchDecodeWithPagedKVCacheDispatched< {{ head_dim }}, {{ pos_encoding_mode }}, AttentionVariant>( - params, tmp_v, tmp_s, /*stream=*/torch_current_stream); + params, tmp_v, tmp_s, stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); - - return o; } +""", + r"""#include "pytorch_extension_utils.h" + +std::vector BatchDecodeWithPagedKVCachePlan( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor indptr, + unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); + +void BatchDecodeWithPagedKVCacheRun( + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + std::vector plan_info_vec, + at::Tensor q, + at::Tensor paged_k_cache, + at::Tensor paged_v_cache, + at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + std::optional alibi_slopes, + at::Tensor o, + unsigned int kv_layout_code, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("plan", &BatchDecodeWithPagedKVCachePlan); m.def("run", &BatchDecodeWithPagedKVCacheRun); } -""" +""", +] diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py index 18b7f6939..8e00ebf3c 100644 --- a/python/flashinfer/jit/batch_prefill_templ.py +++ b/python/flashinfer/jit/batch_prefill_templ.py @@ -14,42 +14,35 @@ limitations under the License. 
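The decode/prefill run entry points above now return void, take a pre-allocated output tensor `o` (and an optional `lse`), and receive the CUDA stream as a plain int64. The sketch below restates the caller-side pattern the Python wrappers adopt for this; it is illustrative only — `call_run` and its trailing argument order are placeholders, not an exact signature from the diff.

from typing import Optional, Tuple, Union

import torch


def call_run(
    run_func, q: torch.Tensor, *args, return_lse: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    with q.device as device:  # device guard, as in the updated Python wrappers
        o = torch.empty_like(q)  # output is now allocated by the caller
        lse = (
            torch.empty(q.size(0), q.size(1), dtype=torch.float32, device=device)
            if return_lse
            else None
        )
        stream = torch.cuda.current_stream(device).cuda_stream  # raw stream as int
        run_func(q, *args, o, lse, stream)  # kernel writes into o in place
    return (o, lse) if return_lse else o
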
""" -batch_prefill_templ = r""" -#include -#include -#include -#include -#include -#include +batch_prefill_suffix = [ + "_plan.cu", + "_ragged_run.cu", + "_paged_run.cu", + "_pybind.cc", +] + +batch_prefill_templ = [ + r"""#include #include "pytorch_extension_utils.h" using namespace flashinfer; -{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} -using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; -using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; - std::vector BatchPrefillWithKVCachePlan( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor qo_indptr, - torch::Tensor kv_indptr, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, + at::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph) { + bool enable_cuda_graph, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU"); - TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU"); - + cudaStream_t stream = reinterpret_cast(cuda_stream); PrefillPlanInfo plan_info; cudaError_t status = PrefillPlan<{{ dtype_idx }}>( @@ -58,24 +51,39 @@ int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{ dtype_idx }}>(), kv_indptr.data_ptr<{{ dtype_idx }}>(), batch_size, num_qo_heads, num_kv_heads, {{ head_dim }}, page_size, enable_cuda_graph, - sizeof({{ dtype_o }}), torch_current_stream); + sizeof({{ dtype_o }}), stream); TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); return plan_info.ToVector(); } +""", + r""" +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; -torch::Tensor BatchPrefillWithRaggedKVCacheRun( +void BatchPrefillWithRaggedKVCacheRun( unsigned int mask_mode_code, - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, std::vector plan_info_vec, - torch::Tensor q, torch::Tensor k, torch::Tensor v, - std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, - torch::Tensor qo_indptr, torch::Tensor kv_indptr, - std::optional maybe_qk_indptr, + at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, + std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse) { + float rope_scale, float rope_theta, 
std::optional maybe_lse, + int64_t cuda_stream) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); @@ -92,15 +100,10 @@ kv_stride_n = k.stride(1); } - auto device = float_workspace_buffer.device(); - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } void* float_buffer_ptr = float_workspace_buffer.data_ptr(); @@ -142,6 +145,7 @@ cudaError_t status = cudaSuccess; MaskMode mask_mode = static_cast(mask_mode_code); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; @@ -149,31 +153,43 @@ DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { status = BatchPrefillWithRaggedKVCacheDispatched< CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>( - params, tmp_v, tmp_s, torch_current_stream); + params, tmp_v, tmp_s, stream); }); }); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status)); - - return o; } +""", + r"""#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; -torch::Tensor BatchPrefillWithPagedKVCacheRun( +void BatchPrefillWithPagedKVCacheRun( unsigned int mask_mode_code, - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, std::vector plan_info_vec, - torch::Tensor q, - torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, - std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, - torch::Tensor qo_indptr, - torch::Tensor paged_kv_indptr, - torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, - std::optional maybe_qk_indptr, + at::Tensor q, + at::Tensor paged_k_cache, + at::Tensor paged_v_cache, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, + at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse) { + float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); @@ -189,14 +205,10 @@ num_kv_heads = paged_k_cache.size(2); } - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, 
"lse must be float32"); } void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); @@ -254,6 +266,7 @@ cudaError_t status = cudaSuccess; MaskMode mask_mode = static_cast(mask_mode_code); + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; @@ -261,18 +274,62 @@ DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { status = BatchPrefillWithPagedKVCacheDispatched< CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>( - params, tmp_v, tmp_s, torch_current_stream); + params, tmp_v, tmp_s, stream); }); }); - TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); - - return o; } +""", + r"""#include "pytorch_extension_utils.h" + +std::vector BatchPrefillWithKVCachePlan( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, + at::Tensor kv_indptr, + unsigned int batch_size, + unsigned int num_qo_heads, + unsigned int num_kv_heads, + unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); + +void BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, + at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, + std::optional maybe_qk_indptr, + at::Tensor o, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); + +void BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, + at::Tensor q, + at::Tensor paged_k_cache, + at::Tensor paged_v_cache, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, + at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, + at::Tensor o, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("plan", &BatchPrefillWithKVCachePlan); m.def("ragged_run", &BatchPrefillWithRaggedKVCacheRun); m.def("paged_run", &BatchPrefillWithPagedKVCacheRun); } -""" +""", +] diff --git a/python/flashinfer/jit/core.py b/python/flashinfer/jit/core.py new file mode 100644 index 000000000..93f946edd --- /dev/null +++ b/python/flashinfer/jit/core.py @@ -0,0 +1,121 @@ +import logging +import os +import re +from pathlib import Path +from typing import List, Union + +import torch.utils.cpp_extension as torch_cpp_ext +from filelock import FileLock + +from .env import CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS +from .env import FLASHINFER_CSRC_DIR as FLASHINFER_CSRC_DIR +from .env import FLASHINFER_GEN_SRC_DIR as FLASHINFER_GEN_SRC_DIR +from .env import FLASHINFER_INCLUDE_DIR as FLASHINFER_INCLUDE_DIR +from .env import FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR +from .env import FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR + +if not os.path.exists(FLASHINFER_WORKSPACE_DIR): + os.makedirs(FLASHINFER_WORKSPACE_DIR) + + +class 
FlashInferJITLogger(logging.Logger): + def __init__(self, name): + super().__init__(name) + self.setLevel(logging.INFO) + self.addHandler(logging.StreamHandler()) + log_path = FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log" + if not os.path.exists(log_path): + # create an empty file + with open(log_path, "w") as f: + pass + self.addHandler(logging.FileHandler(log_path)) + # set the format of the log + self.handlers[0].setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + self.handlers[1].setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + + def info(self, msg): + super().info("flashinfer.jit: " + msg) + + +logger = FlashInferJITLogger("flashinfer.jit") + + +def check_cuda_arch(): + # cuda arch check for fp8 at the moment. + for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): + arch = int(re.search(r"compute_(\d+)", cuda_arch_flags).group(1)) + if arch < 75: + raise RuntimeError("FlashInfer requires sm75+") + + +def clear_cache_dir(): + if os.path.exists(FLASHINFER_JIT_DIR): + for file in os.listdir(FLASHINFER_JIT_DIR): + os.remove(os.path.join(FLASHINFER_JIT_DIR, file)) + + +def remove_unwanted_pytorch_nvcc_flags(): + REMOVE_NVCC_FLAGS = [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + for flag in REMOVE_NVCC_FLAGS: + try: + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + + +remove_unwanted_pytorch_nvcc_flags() + + +def load_cuda_ops( + name: str, + sources: List[Union[str, Path]], + extra_cflags: List[str] = [], + extra_cuda_cflags: List[str] = [], + extra_ldflags=None, + extra_include_paths=None, + verbose=False, +): + cflags = ["-O3", "-Wno-switch-bool"] + cuda_cflags = [ + "-O3", + "-std=c++17", + "--threads", + "4", + "-use_fast_math", + "-DFLASHINFER_ENABLE_BF16", + "-DFLASHINFER_ENABLE_FP8", + ] + cflags += extra_cflags + cuda_cflags += extra_cuda_cflags + logger.info(f"Loading JIT ops: {name}") + check_cuda_arch() + build_directory = FLASHINFER_JIT_DIR / name + if not os.path.exists(build_directory): + os.makedirs(build_directory, exist_ok=True) + if extra_include_paths is None: + extra_include_paths = [ + FLASHINFER_INCLUDE_DIR, + FLASHINFER_CSRC_DIR, + ] + CUTLASS_INCLUDE_DIRS + lock = FileLock(FLASHINFER_JIT_DIR / f"{name}.lock", thread_local=False) + with lock: + return torch_cpp_ext.load( + name, + list(map(lambda _: str(_), sources)), + extra_cflags=cflags, + extra_cuda_cflags=cuda_cflags, + extra_ldflags=extra_ldflags, + extra_include_paths=list(map(lambda _: str(_), extra_include_paths)), + build_directory=build_directory, + verbose=verbose, + with_cuda=True, + ) diff --git a/python/flashinfer/jit/single_decode_templ.py b/python/flashinfer/jit/single_decode_templ.py index 89a2ed7e2..45c2956fa 100644 --- a/python/flashinfer/jit/single_decode_templ.py +++ b/python/flashinfer/jit/single_decode_templ.py @@ -14,12 +14,19 @@ limitations under the License. 
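The gen_*_module helpers in attention.py and the new core.py above all follow the same render/write/compile flow. A condensed restatement of that pattern is sketched below, assuming the package layout introduced in this diff; the function name gen_example_module and its arguments are placeholders, not part of the change.

import os
from typing import List

import jinja2

from flashinfer.jit.core import load_cuda_ops
from flashinfer.jit.env import FLASHINFER_GEN_SRC_DIR
from flashinfer.jit.utils import write_if_different


def gen_example_module(uri: str, template_strs: List[str], suffixes: List[str], context: dict):
    os.makedirs(FLASHINFER_GEN_SRC_DIR, exist_ok=True)
    # Each template list is paired with a suffix list of the same length,
    # e.g. ["_plan.cu", "_run.cu", "_pybind.cc"].
    sources = [jinja2.Template(t).render(**context) for t in template_strs]
    source_paths = []
    for suffix, source in zip(suffixes, sources):
        path = FLASHINFER_GEN_SRC_DIR / f"{uri}{suffix}"
        write_if_different(path, source)  # avoid touching unchanged sources
        source_paths.append(path)
    # JIT-compile the sources into a torch extension exposing whatever the
    # *_pybind.cc stub registered (e.g. plan/run).
    return load_cuda_ops(uri, source_paths)
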
""" -customizable_single_decode_templ = r""" -#include +single_decode_suffix = [ + ".cu", + "_pybind.cc", +] + +customizable_single_decode_templ = [ + r""" #include #include #include "pytorch_extension_utils.h" +using namespace flashinfer; + struct SingleDecodeParams { using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; @@ -91,9 +98,10 @@ {{ variant_decl }} -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, - unsigned int layout, int window_left{{ additional_func_params }}) { +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor tmp, at::Tensor o, + unsigned int layout, int window_left{{ additional_func_params }}, + int64_t cuda_stream) { auto device = q.device(); unsigned int num_qo_heads = q.size(0); unsigned int head_dim = q.size(1); @@ -106,9 +114,6 @@ num_kv_heads = k.size(0); kv_len = k.size(1); } - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q); using ParamsT = SingleDecodeParams; using AttentionVariant = {{ variant_name }}; @@ -117,35 +122,44 @@ static_cast<{{ dtype_kv }}*>(v.data_ptr()), static_cast<{{ dtype_o }}*>(o.data_ptr()), kv_len, num_qo_heads, num_kv_heads, kv_layout, head_dim, window_left{{ additional_params_data }}); + cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = SingleDecodeWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, AttentionVariant>( - params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); - - return o; } +""", + r"""#include "pytorch_extension_utils.h" + +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor tmp, at::Tensor o, + unsigned int layout, int window_left{{ additional_func_params }}, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("run", &single_decode_with_kv_cache, "Single-request decode with KV-Cache operator"); } -""" +""", +] -single_decode_templ = r""" -#include +single_decode_templ = [ + r""" #include #include #include #include #include "pytorch_extension_utils.h" +using namespace flashinfer; + {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, std::optional alibi_slopes, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta) { +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor tmp, std::optional alibi_slopes, + at::Tensor o, unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, int64_t cuda_stream) { auto device = q.device(); unsigned int num_qo_heads = q.size(0); unsigned int head_dim = q.size(1); @@ -158,10 +172,8 @@ num_kv_heads = k.size(0); kv_len = k.size(1); } - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q); + cudaStream_t stream = reinterpret_cast(cuda_stream); using ParamsT = SingleDecodeParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; using 
AttentionVariant = ComposedAttention; ParamsT params( @@ -172,16 +184,22 @@ logits_soft_cap, sm_scale, rope_scale, rope_theta); cudaError_t status = SingleDecodeWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, AttentionVariant>( - params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); - - return o; } +""", + r"""#include "pytorch_extension_utils.h" +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor tmp, std::optional alibi_slopes, + at::Tensor o, unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("run", &single_decode_with_kv_cache, "Single-request decode with KV-Cache operator"); } -""" +""", +] diff --git a/python/flashinfer/jit/single_prefill_templ.py b/python/flashinfer/jit/single_prefill_templ.py index d9ca1a3a3..83e5ebcb6 100644 --- a/python/flashinfer/jit/single_prefill_templ.py +++ b/python/flashinfer/jit/single_prefill_templ.py @@ -14,8 +14,13 @@ limitations under the License. """ -customizable_single_prefill_templ = r""" -#include +single_prefill_suffix = [ + ".cu", + "_pybind.cc", +] + +customizable_single_prefill_templ = [ + r""" #include #include #include "pytorch_extension_utils.h" @@ -81,9 +86,11 @@ {{ variant_decl }} -torch::Tensor single_prefill_with_kv_cache( - unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, unsigned int layout, int32_t window_left, std::optional maybe_lse{{ additional_func_params }}) { +at::Tensor single_prefill_with_kv_cache( + unsigned int mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor tmp, at::Tensor o, unsigned int layout, int32_t window_left, + std::optional maybe_lse{{ additional_func_params }}, + int64_t cuda_stream) { auto device = q.device(); unsigned int head_dim = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; @@ -102,14 +109,11 @@ kv_stride_h = k.stride(0); kv_stride_n = k.stride(1); } - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); + if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } using ParamsT = SinglePrefillParams; @@ -123,11 +127,11 @@ kv_stride_n, kv_stride_h, head_dim, window_left{{ additional_params_data }}); MaskMode mask_mode = static_cast(mask_mode_code); - + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { cudaError_t status = SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, MASK_MODE, AttentionVariant>( - params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); @@ -135,15 +139,24 @@ return o; } +""", + r"""#include "pytorch_extension_utils.h" + +at::Tensor single_prefill_with_kv_cache( + unsigned int 
mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor tmp, at::Tensor o, unsigned int layout, int32_t window_left, + std::optional maybe_lse{{ additional_func_params }}, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("run", &single_prefill_with_kv_cache, "Single-request prefill attention with KV-Cache operator"); } -""" +""", +] -single_prefill_templ = r""" -#include +single_prefill_templ = [ + r""" #include #include #include @@ -155,11 +168,13 @@ {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; -torch::Tensor single_prefill_with_kv_cache( +void single_prefill_with_kv_cache( unsigned int mask_mode_code, - torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional maybe_packed_custom_mask, - torch::Tensor tmp, std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse) { + at::Tensor q, at::Tensor k, at::Tensor v, std::optional maybe_packed_custom_mask, + at::Tensor tmp, std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { auto device = q.device(); unsigned int head_dim = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; @@ -178,14 +193,10 @@ kv_stride_h = k.stride(0); kv_stride_n = k.stride(1); } - const at::cuda::OptionalCUDAGuard device_guard(device); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } ParamsT params( @@ -199,25 +210,33 @@ kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); - MaskMode mask_mode = static_cast(mask_mode_code); - + cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; using AttentionVariant = ComposedAttention; cudaError_t status = SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, AttentionVariant>( - params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); }); - - return o; } +""", + r"""#include "pytorch_extension_utils.h" + +void single_prefill_with_kv_cache( + unsigned int mask_mode_code, + at::Tensor q, at::Tensor k, at::Tensor v, std::optional maybe_packed_custom_mask, + at::Tensor tmp, std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("run", &single_prefill_with_kv_cache, "Single-request prefill attention with KV-Cache operator"); } -""" +""", +] diff --git a/python/flashinfer/jit/utils.py 
b/python/flashinfer/jit/utils.py
index c96557b78..b35d216e8 100644
--- a/python/flashinfer/jit/utils.py
+++ b/python/flashinfer/jit/utils.py
@@ -24,6 +24,8 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
         with open(path, "r") as f:
             if f.read() == content:
                 return
+    else:
+        path.parent.mkdir(parents=True, exist_ok=True)
     with open(path, "w") as f:
         f.write(content)
 
diff --git a/python/flashinfer/norm.py b/python/flashinfer/norm.py
index 3cbb081a0..1919296fb 100644
--- a/python/flashinfer/norm.py
+++ b/python/flashinfer/norm.py
@@ -19,7 +19,7 @@
 import torch
 
 from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
-from .utils import register_custom_op, register_fake_op
+from .utils import get_cuda_stream, register_custom_op, register_fake_op
 
 _norm_module = None
 
@@ -78,7 +78,8 @@ def rmsnorm(
 def _rmsnorm(
     out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
 ) -> None:
-    get_norm_module().rmsnorm(out, input, weight, eps)
+    with input.device as device:  # device guard
+        get_norm_module().rmsnorm(out, input, weight, eps, get_cuda_stream(device))
 
 
 @register_fake_op("flashinfer::rmsnorm")
@@ -111,7 +112,10 @@ def fused_add_rmsnorm(
     eps: float
         Epsilon for numerical stability.
     """
-    get_norm_module().fused_add_rmsnorm(input, residual, weight, eps)
+    with input.device as device:  # device guard
+        get_norm_module().fused_add_rmsnorm(
+            input, residual, weight, eps, get_cuda_stream(device)
+        )
 
 
 @register_fake_op("flashinfer::fused_add_rmsnorm")
@@ -157,7 +161,10 @@ def gemma_rmsnorm(
 def _gemma_rmsnorm(
     out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
 ) -> None:
-    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+    with input.device as device:  # device guard
+        get_norm_module().gemma_rmsnorm(
+            out, input, weight, eps, get_cuda_stream(device)
+        )
 
 
 @register_fake_op("flashinfer::gemma_rmsnorm")
@@ -192,7 +199,10 @@ def gemma_fused_add_rmsnorm(
     eps: float
         Epsilon for numerical stability.
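The norm wrappers above import get_cuda_stream from .utils, but its definition is not part of this hunk. The sketch below is an assumption about what it does — return the raw stream handle of the device's current stream as a Python int, which the C++ side receives as int64_t and casts back to a cudaStream_t — not code taken from the diff.

import torch


def get_cuda_stream(device: torch.device) -> int:  # assumed helper, not from the diff
    # cuda_stream is the underlying cudaStream_t pointer exposed as an integer.
    return torch.cuda.current_stream(device).cuda_stream
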
""" - get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps) + with input.device as device: + get_norm_module().gemma_fused_add_rmsnorm( + input, residual, weight, eps, get_cuda_stream(device) + ) @register_fake_op("flashinfer::gemma_fused_add_rmsnorm") diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 7abfb0ba4..a008fb085 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -25,6 +25,7 @@ TensorLayout, _check_kv_layout, _unpack_paged_kv_cache, + get_cuda_stream, register_custom_op, register_fake_op, ) @@ -66,18 +67,25 @@ def _append_paged_kv_cache_kernel( kv_last_page_len: torch.Tensor, layout: int, ) -> None: - get_page_module().append_paged_kv_cache( - append_key, - append_value, - batch_indices, - positions, - paged_k_cache, - paged_v_cache, - kv_indices, - kv_indptr, - kv_last_page_len, - layout, - ) + with append_key.device as device: + batch_indices = batch_indices.int() + positions = positions.int() + kv_indices = kv_indices.int() + kv_indptr = kv_indptr.int() + kv_last_page_len = kv_last_page_len.int() + get_page_module().append_paged_kv_cache( + append_key, + append_value, + batch_indices, + positions, + paged_k_cache, + paged_v_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + layout, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::append_paged_kv_cache") diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index ec6515be0..cf43729a6 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -23,8 +23,8 @@ import torch from .jit import ( - gen_batch_prefill_cu, - gen_single_prefill_cu, + gen_batch_prefill_module, + gen_single_prefill_module, get_batch_prefill_uri, get_single_prefill_uri, has_prebuilt_ops, @@ -43,36 +43,12 @@ _get_cache_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, + get_cuda_stream, is_float8, register_custom_op, register_fake_op, ) - -def compile_single_prefill_module( - *args, - verbose: bool = False, -): - uri, path = gen_single_prefill_cu(*args) - return load_cuda_ops( - uri, - [path], - verbose=verbose, - ) - - -def compile_batch_prefill_module( - *args, - verbose: bool = False, -): - uri, path = gen_batch_prefill_cu(*args) - return load_cuda_ops( - uri, - [path], - verbose=verbose, - ) - - _single_prefill_modules = {} _batch_prefill_modules = {} @@ -86,7 +62,7 @@ def get_single_prefill_module(*args): run_func = _kernels.single_prefill_with_kv_cache else: - run_func = compile_single_prefill_module(*args).run + run_func = gen_single_prefill_module(*args).run # torch library for single_prefill_with_kv_cache @@ -107,22 +83,27 @@ def run_single_prefill( rope_theta: float, maybe_lse: Optional[torch.Tensor], ) -> torch.Tensor: - return run_func( - mask_mode, - q, - k, - v, - maybe_packed_custom_mask, - tmp, - maybe_alibi_slopes, - layout, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - maybe_lse, - ) + with q.device as device: # device guard + o = torch.empty_like(q) + run_func( + mask_mode, + q, + k, + v, + maybe_packed_custom_mask, + tmp, + maybe_alibi_slopes, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o @register_fake_op(f"flashinfer::{uri}_run") def _fake_run_single_prefill( @@ -165,7 +146,7 @@ def get_batch_prefill_module(*args): ragged_run_func = _kernels.batch_prefill_with_ragged_kv_cache_run paged_run_func = _kernels.batch_prefill_with_paged_kv_cache_run else: - module = 
compile_batch_prefill_module(*args) + module = gen_batch_prefill_module(*args) plan_func = module.plan ragged_run_func = module.ragged_run paged_run_func = module.paged_run @@ -201,27 +182,32 @@ def ragged_run( rope_theta: float, maybe_lse: Optional[torch.Tensor], ) -> torch.Tensor: - return ragged_run_func( - mask_mode, - float_workspace_buffer, - int_workspace_buffer, - plan_info_vec, - q, - k, - v, - maybe_custom_mask, - maybe_alibi_slopes, - qo_indptr, - kv_indptr, - maybe_qk_indptr, - layout, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - maybe_lse, - ) + with q.device as device: # device guard + o = torch.empty_like(q) + ragged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + kv_indptr, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o @register_fake_op(f"flashinfer::{uri}_ragged_run") def _fake_ragged_run( @@ -282,29 +268,34 @@ def paged_run( rope_theta: float, maybe_lse: Optional[torch.Tensor], ) -> torch.Tensor: - return paged_run_func( - mask_mode, - float_workspace_buffer, - int_workspace_buffer, - plan_info_vec, - q, - paged_k_cache, - paged_v_cache, - maybe_custom_mask, - maybe_alibi_slopes, - qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - maybe_qk_indptr, - layout, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - maybe_lse, - ) + with q.device as device: # device guard + o = torch.empty_like(q) + paged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o @register_fake_op(f"flashinfer::{uri}_paged_run") def _fake_paged_run( @@ -355,14 +346,30 @@ def single_prefill_with_kv_cache_with_jit_module( window_left: int = -1, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) - lse = None - if return_lse: - lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) - out = jit_module.run( - mask_mode, q, k, v, tmp, TensorLayout[kv_layout].value, window_left, lse, *args - ) - return (out, lse) if return_lse else out + with q.device as device: # device guard + tmp = _get_cache_buf( + "single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, device=device + ) + o = torch.empty_like(q) + lse = None + if return_lse: + lse = torch.empty( + (q.size(0), q.size(1)), dtype=torch.float32, device=device + ) + jit_module.run( + mask_mode, + q, + k, + v, + tmp, + o, + TensorLayout[kv_layout].value, + window_left, + lse, + *args, + get_cuda_stream(device), + ) + return (o, lse) if return_lse else o @overload @@ -1053,18 +1060,20 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - paged_kv_indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - ) + with 
self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + paged_kv_indptr_host, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction @@ -1640,18 +1649,20 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - 1, # page_size - self.is_cuda_graph_enabled, - ) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction diff --git a/python/flashinfer/quantization.py b/python/flashinfer/quantization.py index 3bd2f6a6d..f5c00340b 100644 --- a/python/flashinfer/quantization.py +++ b/python/flashinfer/quantization.py @@ -19,7 +19,7 @@ import torch from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops -from .utils import register_custom_op, register_fake_op +from .utils import get_cuda_stream, register_custom_op, register_fake_op _quantization_module = None @@ -44,7 +44,11 @@ def get_quantization_module(): @register_custom_op("flashinfer::packbits", mutates_args=()) def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: - return get_quantization_module().packbits(x, bitorder) + with x.device as device: # device guard + x = x.to(torch.bool) + y = torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=device) + get_quantization_module().packbits(x, bitorder, y, get_cuda_stream(device)) + return y @register_fake_op("flashinfer::packbits") @@ -136,7 +140,16 @@ def segment_packbits( packed_len = (seglen + 7) // 8 indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device) indptr_new[1:] = torch.cumsum(packed_len, 0) - return ( - get_quantization_module().segment_packbits(x, indptr, indptr_new, bitorder), - indptr_new, - ) + output_nnzs = indptr_new[-1].item() + + with x.device as device: + indptr = indptr.to(torch.int32) + indptr_new = indptr_new.to(torch.int32) + y = torch.empty(output_nnzs, dtype=torch.uint8, device=device) + get_quantization_module().segment_packbits( + x, indptr, indptr_new, bitorder, y, get_cuda_stream(device) + ) + return ( + y, + indptr_new, + ) diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index ef2f20b2f..17266b4e9 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -19,7 +19,7 @@ import torch from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops -from .utils import register_custom_op, register_fake_op +from .utils import get_cuda_stream, register_custom_op, register_fake_op _rope_module = None @@ -55,18 +55,20 @@ def _apply_rope( rope_scale: float, rope_theta: float, ) -> None: - get_rope_module().apply_rope( - q, - k, - q_rope, - k_rope, - indptr, - offsets, - rotary_dim, - interleave, - rope_scale, - 
rope_theta, - ) + with q.device as device: + get_rope_module().apply_rope( + q, + k, + q_rope, + k_rope, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::apply_rope") @@ -101,21 +103,23 @@ def _apply_llama31_rope( high_freq_factor: float, old_context_len: float, ) -> None: - get_rope_module().apply_llama31_rope( - q, - k, - q_rope, - k_rope, - indptr, - offsets, - rotary_dim, - interleave, - rope_scale, - rope_theta, - low_freq_factor, - high_freq_factor, - old_context_len, - ) + with q.device as device: + get_rope_module().apply_llama31_rope( + q, + k, + q_rope, + k_rope, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + old_context_len, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::apply_llama31_rope") @@ -149,9 +153,19 @@ def _apply_rope_pos_ids( rope_scale: float, rope_theta: float, ) -> None: - get_rope_module().apply_rope_pos_ids( - q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta - ) + with q.device as device: + get_rope_module().apply_rope_pos_ids( + q, + k, + q_rope, + k_rope, + pos_ids, + rotary_dim, + interleave, + rope_scale, + rope_theta, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::apply_rope_pos_ids") @@ -182,16 +196,18 @@ def _apply_rope_pos_ids_cos_sin_cache( pos_ids: torch.Tensor, interleave: bool, ) -> None: - get_rope_module().apply_rope_pos_ids_cos_sin_cache( - q, - k, - q_rope, - k_rope, - cos_cache, - sin_cache, - pos_ids, - interleave, - ) + with q.device as device: + get_rope_module().apply_rope_pos_ids_cos_sin_cache( + q, + k, + q_rope, + k_rope, + cos_cache, + sin_cache, + pos_ids, + interleave, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache") @@ -225,20 +241,22 @@ def _apply_llama31_rope_pos_ids( high_freq_factor: float, old_context_len: float, ) -> None: - get_rope_module().apply_llama31_rope_pos_ids( - q, - k, - q_rope, - k_rope, - pos_ids, - rotary_dim, - interleave, - rope_scale, - rope_theta, - low_freq_factor, - high_freq_factor, - old_context_len, - ) + with q.device as device: + get_rope_module().apply_llama31_rope_pos_ids( + q, + k, + q_rope, + k_rope, + pos_ids, + rotary_dim, + interleave, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + old_context_len, + get_cuda_stream(device), + ) @register_fake_op("flashinfer::apply_llama31_rope_pos_ids") diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py index 0d59edf4f..896721c60 100644 --- a/python/flashinfer/sampling.py +++ b/python/flashinfer/sampling.py @@ -20,7 +20,7 @@ import torch from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops -from .utils import register_custom_op, register_fake_op +from .utils import get_cuda_stream, register_custom_op, register_fake_op _sampling_module = None @@ -49,7 +49,18 @@ def sampling_from_probs( uniform_samples: torch.Tensor, deterministic: bool, ) -> torch.Tensor: - return module.sampling_from_probs(probs, uniform_samples, deterministic) + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + module.sampling_from_probs( + probs, + uniform_samples, + samples, + deterministic, + get_cuda_stream(device), + ) + return samples @register_fake_op("flashinfer::sampling_from_probs") def _fake_sampling_from_probs( @@ -69,10 +80,25 @@ def 
top_p_sampling_from_probs( top_p_val: float, deterministic: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - samples, success = module.top_p_sampling_from_probs( - probs, uniform_samples, maybe_top_p_arr, top_p_val, deterministic - ) - return samples, success + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + module.top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_p_arr, + top_p_val, + deterministic, + get_cuda_stream(device), + ) + return samples, success @register_fake_op("flashinfer::top_p_sampling_from_probs") def _fake_top_p_sampling_from_probs( @@ -96,10 +122,25 @@ def top_k_sampling_from_probs( top_k_val: int, deterministic: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - samples, success = module.top_k_sampling_from_probs( - probs, uniform_samples, maybe_top_k_arr, top_k_val, deterministic - ) - return samples, success + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_k_arr = ( + maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + module.top_k_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_k_arr, + top_k_val, + deterministic, + get_cuda_stream(device), + ) + return samples, success @register_fake_op("flashinfer::top_k_sampling_from_probs") def _fake_top_k_sampling_from_probs( @@ -123,10 +164,25 @@ def min_p_sampling_from_probs( min_p_val: float, deterministic: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - samples, success = module.min_p_sampling_from_probs( - probs, uniform_samples, maybe_min_p_arr, min_p_val, deterministic - ) - return samples, success + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_min_p_arr = ( + maybe_min_p_arr.float() if maybe_min_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + module.min_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_min_p_arr, + min_p_val, + deterministic, + get_cuda_stream(device), + ) + return samples, success # torch library for top_k_top_p_sampling_from_probs @@ -142,16 +198,30 @@ def top_k_top_p_sampling_from_probs( top_p_val: float, deterministic: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - samples, success = module.top_k_top_p_sampling_from_probs( - probs, - uniform_samples, - maybe_top_k_arr, - top_k_val, - maybe_top_p_arr, - top_p_val, - deterministic, - ) - return samples, success + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_k_arr = ( + maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + ) + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + module.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_k_arr, + top_k_val, + 
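As with the prefill wrappers, the sampling wrappers above (sampling_from_probs, top_p_sampling_from_probs, top_k_sampling_from_probs, min_p_sampling_from_probs) now normalize input dtypes on the Python side and allocate the kernel outputs themselves. A sketch of that preparation step, with made-up sizes and CPU tensors and no kernel call:

    import torch

    batch_size, vocab_size = 4, 32000                      # hypothetical sizes
    probs = torch.rand(batch_size, vocab_size, dtype=torch.float16)
    probs = probs.float()                                  # kernels expect float32 probabilities
    uniform_samples = torch.rand(batch_size).float()       # one uniform draw per row for plain sampling
    samples = torch.empty(batch_size, dtype=torch.int32)   # token ids, written by the kernel
    success = torch.empty(batch_size, dtype=torch.bool)    # per-row success flag for the rejection-based variants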
maybe_top_p_arr, + top_p_val, + deterministic, + get_cuda_stream(device), + ) + return samples, success @register_fake_op("flashinfer::top_k_top_p_sampling_from_probs") def _fake_top_k_top_p_sampling_from_probs( @@ -175,7 +245,20 @@ def top_p_renorm_probs( maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, ) -> torch.Tensor: - return module.top_p_renorm_probs(probs, maybe_top_p_arr, top_p_val) + with probs.device as device: + probs = probs.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + renorm_probs = torch.empty_like(probs) + module.top_p_renorm_probs( + probs, + renorm_probs, + maybe_top_p_arr, + top_p_val, + get_cuda_stream(device), + ) + return renorm_probs @register_fake_op("flashinfer::top_p_renorm_probs") def _fake_top_p_renorm_probs( @@ -193,7 +276,20 @@ def top_k_renorm_probs( maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, ) -> torch.Tensor: - return module.top_k_renorm_probs(probs, maybe_top_k_arr, top_k_val) + with probs.device as device: + probs = probs.float() + maybe_top_k_arr = ( + maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + ) + renorm_probs = torch.empty_like(probs) + module.top_k_renorm_probs( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + get_cuda_stream(device), + ) + return renorm_probs @register_fake_op("flashinfer::top_k_renorm_probs") def _fake_top_k_renorm_probs( @@ -211,7 +307,20 @@ def top_k_mask_logits( maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, ) -> torch.Tensor: - return module.top_k_mask_logits(logits, maybe_top_k_arr, top_k_val) + with logits.device as device: + logits = logits.float() + maybe_top_k_arr = ( + maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + ) + mask_logits = torch.empty_like(logits) + module.top_k_mask_logits( + logits, + mask_logits, + maybe_top_k_arr, + top_k_val, + get_cuda_stream(device), + ) + return mask_logits @register_fake_op("flashinfer::top_k_mask_logits") def _fake_top_k_mask_logits( @@ -236,15 +345,29 @@ def chain_speculative_sampling( output_emitted_token_num: torch.Tensor, deterministic: bool, ) -> torch.Tensor: - return module.chain_speculative_sampling( - draft_probs, - draft_token_ids, - uniform_samples, - target_probs, - output_accepted_token_num, - output_emitted_token_num, - deterministic, - ) + with draft_probs.device as device: + draft_probs = draft_probs.float() + draft_token_ids = draft_token_ids.int() + uniform_samples = uniform_samples.float() + target_probs = target_probs.float() + output_accepted_token_num = output_accepted_token_num.int() + output_emitted_token_num = output_emitted_token_num.int() + b, n = draft_token_ids.shape + output_token_ids = torch.empty( + (b, n + 1), dtype=torch.int32, device=device + ) + module.chain_speculative_sampling( + draft_probs, + draft_token_ids, + uniform_samples, + target_probs, + output_token_ids, + output_accepted_token_num, + output_emitted_token_num, + deterministic, + get_cuda_stream(device), + ) + return output_token_ids @register_fake_op("flashinfer::chain_speculative_sampling") def _fake_chain_speculative_sampling( diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py index 19c4734e5..1232538f7 100644 --- a/python/flashinfer/sparse.py +++ b/python/flashinfer/sparse.py @@ -30,6 +30,7 @@ _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, canonicalize_torch_dtype, + get_cuda_stream, ) @@ -323,17 +324,19 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap ) - self._plan_info = self._cached_module.plan( - 
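One shape detail worth noting in the chain_speculative_sampling rework above: for a draft of n tokens per request, the wrapper allocates an output of shape (b, n + 1), leaving room for the accepted draft tokens plus the single bonus token drawn from the target distribution. A minimal illustration with hypothetical sizes:

    import torch

    batch_size, num_draft_tokens = 2, 4                    # hypothetical sizes
    draft_token_ids = torch.randint(0, 32000, (batch_size, num_draft_tokens), dtype=torch.int32)
    b, n = draft_token_ids.shape
    output_token_ids = torch.empty((b, n + 1), dtype=torch.int32)  # n draft slots + 1 bonus token per request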
self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - kv_indptr_host, - num_blocks_row, - num_qo_heads, - num_kv_heads, - C, - False, # is_cuda_graph_enabled - ) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + kv_indptr_host, + num_blocks_row, + num_qo_heads, + num_kv_heads, + C, + False, # is_cuda_graph_enabled + get_cuda_stream(device), + ) else: # if the operation is compute-bound, we use the tensor-core implementation self._use_tensor_cores = True @@ -349,18 +352,20 @@ def plan( allow_fp16_qk_reduction, ) - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - num_blocks_row, - num_qo_heads, - num_kv_heads, - C, - False, # is_cuda_graph_enabled - ) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + num_blocks_row, + num_qo_heads, + num_kv_heads, + C, + False, # is_cuda_graph_enabled + get_cuda_stream(device), + ) self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 55ee36037..69ef89215 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -248,3 +248,7 @@ def register_fake_op( fn: Optional[Callable] = None, ) -> Callable: return torch.library.register_fake(name, fn) + + +def get_cuda_stream(device: torch.device) -> int: + return torch.cuda.current_stream(device).cuda_stream diff --git a/python/jit_MANIFEST.in b/python/jit_MANIFEST.in index 330225759..d2c609dc4 100644 --- a/python/jit_MANIFEST.in +++ b/python/jit_MANIFEST.in @@ -4,7 +4,6 @@ global-exclude *.so prune */__pycache__ prune csrc -prune csrc_aot exclude flashinfer/jit/aot_config.py exclude aot_setup.py exclude mypy.ini diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index f9f931ea7..05f07f08f 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -184,5 +184,5 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { .add_int64_axis("num_qo_heads", {32}) \ .add_int64_axis("num_kv_heads", {32, 4}) -kENCH_FLASHINFER_BATCH_DECODE(half, half); +BENCH_FLASHINFER_BATCH_DECODE(half, half); BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(half, half); diff --git a/src/cpu_reference.h b/src/cpu_reference.h index 73174da73..d3d418de4 100644 --- a/src/cpu_reference.h +++ b/src/cpu_reference.h @@ -15,9 +15,10 @@ */ #pragma once +#include + #include #include -#include #include "utils.h" @@ -120,7 +121,7 @@ std::vector single_mha(const std::vector& q, const std::vect default: { std::ostringstream err_msg; err_msg << "Unsupported rotary mode."; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } } // apply mask diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index ac8c9db52..142add1a4 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -22,6 +22,7 @@ #include "flashinfer/allocator.h" #include "flashinfer/attention/mask.cuh" #include "flashinfer/attention/scheduler.cuh" +#include "flashinfer/exception.h" #include "flashinfer/layout.cuh" #include "utils.h" @@ -495,7 +496,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, 
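The new flashinfer.utils.get_cuda_stream helper above is what every wrapper in this PR threads through to the C++ side: the integer value of the device's current CUDA stream (a raw cudaStream_t). Because it reads the current stream at call time, kernels launched through these wrappers follow torch.cuda.stream(...) guards on the caller's side. A small sketch of that behaviour (requires a CUDA device):

    import torch

    def get_cuda_stream(device: torch.device) -> int:
        return torch.cuda.current_stream(device).cuda_stream

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        side_stream = torch.cuda.Stream(device)
        with torch.cuda.stream(side_stream):
            # Inside the guard, the helper resolves to the side stream's handle,
            # so kernels launched with it run on side_stream rather than the default stream.
            assert get_cuda_stream(device) == side_stream.cuda_stream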
DTypeKV* k, DTypeKV* v, DTypeO* o std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " << num_kv_heads; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } DISPATCH_head_dim( @@ -546,7 +547,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " << num_kv_heads; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } DISPATCH_head_dim( @@ -585,7 +586,7 @@ cudaError_t BatchDecodeHandlerPlan(BatchDecodeHandler* handler, void* float_buff std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " << num_kv_heads; - throw std::invalid_argument(err_msg.str()); + FLASHINFER_ERROR(err_msg.str()); } DISPATCH_head_dim(head_dim, HEAD_DIM, { DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { diff --git a/src/utils.h b/src/utils.h index 17501e196..9b236b8ed 100644 --- a/src/utils.h +++ b/src/utils.h @@ -28,9 +28,9 @@ #include #include -#include #include "dispatch.inc" +#include "flashinfer/exception.h" #define _DISPATCH_SWITCH(var_name, cond, ...) \ switch (cond) { \ @@ -38,7 +38,7 @@ default: \ std::ostringstream oss; \ oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ - throw std::invalid_argument(oss.str()); \ + FLASHINFER_ERROR(oss.str()); \ } #define _DISPATCH_CASE(case_expr, case_var, ...) \ diff --git a/tests/test_activation.py b/tests/test_activation.py index ef85c72f5..8bfe479a9 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -48,3 +48,6 @@ def test_fused_gelu_mul(dim, batch_size, seq_len): y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") y = flashinfer.activation.gelu_and_mul(x) torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +test_fused_silu_mul(128, 1, 1) diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index 241aa6582..f65f33599 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -5,8 +5,10 @@ import torch from flashinfer.decode import single_decode_with_kv_cache_with_jit_module from flashinfer.jit.attention import ( - get_customize_single_decode_cu_str, - get_customize_single_prefill_cu_str, + gen_customize_single_decode_module, + gen_customize_single_prefill_module, + single_decode_suffix, + single_prefill_suffix, ) from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module from flashinfer.utils import MaskMode @@ -14,19 +16,6 @@ import flashinfer -def generate_module(module_name: str, cuda_ops_str: str): - gen_directory = flashinfer.jit.FLASHINFER_GEN_SRC_DIR - flashinfer.jit.utils.write_if_different( - gen_directory / f"{module_name}.cu", - cuda_ops_str, - ) - - return flashinfer.jit.load_cuda_ops( - module_name, - [gen_directory / f"{module_name}.cu"], - ) - - def test_single_decode_mask(): torch.manual_seed(42) variant_decl = r""" @@ -70,7 +59,8 @@ def test_single_decode_mask(): } }; """ - cuda_ops_str = get_customize_single_decode_cu_str( + jit_module = gen_customize_single_decode_module( + "single_decode_with_custom_mask", torch.float16, # dtype_q torch.float16, # dtype_kv torch.float16, # dtype_o @@ -83,7 +73,6 @@ def test_single_decode_mask(): variant_decl, ) - jit_module = generate_module("single_decode_with_custom_mask", cuda_ops_str) f = functools.partial(single_decode_with_kv_cache_with_jit_module, jit_module) q = 
torch.randn(32, 128, dtype=torch.float16, device="cuda") @@ -145,7 +134,8 @@ def test_flash_sigmoid(): } }; """ - cuda_ops_str = get_customize_single_prefill_cu_str( + jit_module = gen_customize_single_prefill_module( + "flash_sigmoid", torch.float16, # dtype_q torch.float16, # dtype_kv torch.float16, # dtype_o @@ -158,7 +148,6 @@ def test_flash_sigmoid(): variant_decl, ) - jit_module = generate_module("flash_sigmoid", cuda_ops_str) f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) q = torch.randn(128, 8, 128, dtype=torch.float16, device="cuda") @@ -219,7 +208,8 @@ def test_dump_logits(): } }; """ - cuda_ops_str = get_customize_single_prefill_cu_str( + jit_module = gen_customize_single_prefill_module( + "dump_logits", torch.float16, # dtype_q torch.float16, # dtype_kv torch.float16, # dtype_o @@ -232,7 +222,6 @@ def test_dump_logits(): variant_decl, ) - jit_module = generate_module("dump_logits", cuda_ops_str) f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda") @@ -293,7 +282,8 @@ def test_debug_print_logits(): } }; """ - cuda_ops_str = get_customize_single_prefill_cu_str( + jit_module = gen_customize_single_prefill_module( + "debug_print_logits", torch.float16, # dtype_q torch.float16, # dtype_kv torch.float16, # dtype_o @@ -306,7 +296,6 @@ def test_debug_print_logits(): variant_decl, ) - jit_module = generate_module("debug_print_logits", cuda_ops_str) f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda") diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py index 9d9e14da9..8fdc0ac63 100644 --- a/tests/test_non_contiguous_decode.py +++ b/tests/test_non_contiguous_decode.py @@ -72,7 +72,3 @@ def test_batch_paged_decode_packed_input( o_packed = wrapper.run(q, paged_kv_cache) o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache) torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) - - -if __name__ == "__main__": - test_batch_paged_decode_packed_input(37, 127, 1, 4, 64, 128) diff --git a/tests/test_rope.py b/tests/test_rope.py index c270d1e94..7fab23617 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -348,6 +348,6 @@ def test_rope_cos_sin_cache( if __name__ == "__main__": - test_rope(2, 1, 8, 8, 1, 128, "llama31", 1.0, False) - test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", 1.0, False) - test_rope_cos_sin_cache(99, 19, 16, 8, 99, 256, "llama31", 0.5, False) + test_rope(2, 1, 8, 8, 1, 128, "llama", 1.0, False) + # test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", 1.0, False) + # test_rope_cos_sin_cache(99, 19, 16, 8, 99, 256, "llama31", 0.5, False) diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py index af478fc39..8ec9deb32 100644 --- a/tests/test_shared_prefix_kernels.py +++ b/tests/test_shared_prefix_kernels.py @@ -80,7 +80,11 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( flashinfer.append_paged_kv_cache( k_shared, v_shared, - shared_append_indptr, + *flashinfer.get_batch_indices_positions( + shared_append_indptr, + flashinfer.get_seq_lens(shared_kv_indptr, shared_last_page_len, page_size), + k_shared.shape[0], + ), kv_data, shared_kv_indices, shared_kv_indptr, @@ -100,7 +104,11 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( flashinfer.append_paged_kv_cache( k_unique, v_unique, - unique_append_indptr, + 
*flashinfer.get_batch_indices_positions( + unique_append_indptr, + flashinfer.get_seq_lens(unique_kv_indptr, unique_last_page_len, page_size), + k_unique.shape[0], + ), kv_data, unique_kv_indices, unique_kv_indptr,
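The test update above reflects an API change to flashinfer.append_paged_kv_cache: instead of a per-request append_indptr, it now takes per-token batch_indices and positions, produced by the get_batch_indices_positions and get_seq_lens helpers. A pure-PyTorch reference sketch of what those helpers are expected to compute, under the assumption that the appended tokens are the most recent entries of each request's KV sequence (the helper names with a _ref suffix are illustrative, not library code):

    import torch

    def get_seq_lens_ref(kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, page_size: int) -> torch.Tensor:
        # total KV length per request: full pages plus the partially filled last page
        num_pages = kv_indptr[1:] - kv_indptr[:-1]
        return torch.clamp(num_pages - 1, min=0) * page_size + kv_last_page_len

    def get_batch_indices_positions_ref(append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int):
        # map each of the nnz appended tokens to its request index and its position within that request's sequence
        append_lens = (append_indptr[1:] - append_indptr[:-1]).long()
        batch_indices = torch.repeat_interleave(torch.arange(len(append_lens)), append_lens)
        positions = torch.cat(
            [torch.arange(int(s) - int(a), int(s)) for a, s in zip(append_lens, seq_lens)]
        )
        assert batch_indices.numel() == positions.numel() == nnz
        return batch_indices, positions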