diff --git a/docs/api/python/decode.rst b/docs/api/python/decode.rst index f78985901..eb4d06a3c 100644 --- a/docs/api/python/decode.rst +++ b/docs/api/python/decode.rst @@ -25,5 +25,9 @@ Batch Decoding .. autoclass:: BatchDecodeWithPagedKVCacheWrapper :members: + .. automethod:: __init__ + .. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper :members: + + .. automethod:: __init__ diff --git a/docs/api/python/group_gemm.rst b/docs/api/python/group_gemm.rst new file mode 100644 index 000000000..b396a320e --- /dev/null +++ b/docs/api/python/group_gemm.rst @@ -0,0 +1,13 @@ +.. _apigroup_gemm: + +flashinfer.group_gemm +===================== + +This module provides a set of functions for grouped GEMM (segment GEMM) operations. + +.. currentmodule:: flashinfer.group_gemm + +.. autoclass:: SegmentGEMMWrapper + :members: + + .. automethod:: __init__ diff --git a/docs/api/python/prefill.rst b/docs/api/python/prefill.rst index 9f50f1953..aad6cbf65 100644 --- a/docs/api/python/prefill.rst +++ b/docs/api/python/prefill.rst @@ -22,6 +22,9 @@ Batch Prefill/Append Attention .. autoclass:: BatchPrefillWithPagedKVCacheWrapper :members: + .. automethod:: __init__ + .. autoclass:: BatchPrefillWithRaggedKVCacheWrapper :members: - \ No newline at end of file + + .. automethod:: __init__ diff --git a/docs/index.rst b/docs/index.rst index 334d3c8b6..8851b7ff1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,4 +32,5 @@ FlashInfer is a library for Large Language Models that provides high-perform api/python/cascade api/python/page api/python/sampling + api/python/group_gemm api/python/norm diff --git a/include/flashinfer/allocator.h b/include/flashinfer/allocator.h new file mode 100644 index 000000000..e4840f167 --- /dev/null +++ b/include/flashinfer/allocator.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ALLOCATOR_H_ +#define FLASHINFER_ALLOCATOR_H_ + +#include <memory> +#include <stdexcept> + +namespace flashinfer { + +struct AlignedAllocator { + void* ptr; + size_t space; + AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {} + template <typename T> + T* aligned_alloc(size_t size, size_t alignment) { + if (std::align(alignment, size, ptr, space)) { + T* result = reinterpret_cast<T*>(ptr); + ptr = (char*)ptr + size; + space -= size; + return result; + } else { + throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAllocator"); + } + return nullptr; + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ALLOCATOR_H_ diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 2f880f162..0fe1750a3 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -#ifndef FLASHINFER_HANDLER_CUH_ -#define FLASHINFER_HANDLER_CUH_ +#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_ +#define FLASHINFER_ATTENTION_HANDLER_CUH_ #include #include -#include #include -#include #include +#include "../allocator.h" #include "../page.cuh" #include "../pos_enc.cuh" #include "../utils.cuh" @@ -241,24 +240,6 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( return cudaSuccess; } -struct AlignedAllocator { - void* ptr; - size_t space; - AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {} - template - T* aligned_alloc(size_t size, size_t alignment) { - if (std::align(alignment, size, ptr, space)) { - T* result = reinterpret_cast(ptr); - ptr = (char*)ptr + size; - space -= size; - return result; - } else { - throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor"); - } - return nullptr; - } -}; - class BatchDecodeHandler { public: template @@ -584,4 +565,4 @@ class BatchPrefillHandler { }; } // namespace flashinfer -#endif // FLASHINFER_HANDLER_CUH_ +#endif // FLASHINFER_ATTENTION_HANDLER_CUH_ diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh new file mode 100644 index 000000000..a3422bef9 --- /dev/null +++ b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_ +#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_ + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +namespace flashinfer { + +namespace group_gemm { + +template +struct cutlass_dtype { + using type = T; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::half_t; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::bfloat16_t; +}; + +template +__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x, + T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w, + int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr, + int64_t* w_indices, size_t d_in, size_t d_out, + bool w_column_major) { + int i = blockIdx.x; + int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; + all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); + ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out; + ptr_x[i] = x + xy_indptr[i] * d_in; + ptr_y[i] = y + xy_indptr[i] * d_out; + ld_x[i] = k; // m * k + ld_w[i] = w_column_major ? 
k : n; // k * n if column major, n * k if row major + ld_y[i] = n; // m * n +} + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GROUP_GEMM_CUTLASS_WRAPPER_CUH_ diff --git a/include/flashinfer/group_gemm/group_gemm_lora.cuh b/include/flashinfer/group_gemm/group_gemm_lora.cuh new file mode 100644 index 000000000..517419da5 --- /dev/null +++ b/include/flashinfer/group_gemm/group_gemm_lora.cuh @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GROUP_GEMM_LORA_CUH_ +#define FLASHINFER_GROUP_GEMM_LORA_CUH_ + +namespace flashinfer { + +namespace group_gemm { + +// TODO(Zihao): port punica's sgmv kernel + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GROUP_GEMM_LORA_CUH_ diff --git a/include/flashinfer/group_gemm/group_gemv.cuh b/include/flashinfer/group_gemm/group_gemv.cuh new file mode 100644 index 000000000..4b439355e --- /dev/null +++ b/include/flashinfer/group_gemm/group_gemv.cuh @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GROUP_GEMV_CUH_ +#define FLASHINFER_GROUP_GEMV_CUH_ + +namespace flashinfer { + +namespace group_gemm { + +// TODO(Zihao): port punica's bgmv kernel + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GROUP_GEMV_CUH_ diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh new file mode 100644 index 000000000..39ef0f783 --- /dev/null +++ b/include/flashinfer/group_gemm/handler.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_ +#define FLASHINFER_GROUP_GEMM_HANDLER_CUH_ + +#include + +#include "../allocator.h" +#include "../utils.cuh" +#include "group_gemm_cutlass.cuh" +#include "group_gemm_lora.cuh" +#include "group_gemv.cuh" + +namespace flashinfer { + +namespace group_gemm { + +enum class GroupGEMMKernelConfig { + kGeneral, // large d_in, d_out + kShrink, // large d_in, small d_out + kExpand, // small d_in, large d_out +}; + +class CutlassSegmentGEMMHandler { + public: + void RegisterWorkspace(void* buffer, size_t size) { + buffer_ = buffer; + workspace_size_in_bytes_ = size; + } + + void* GetWorkspace() const { return buffer_; } + + size_t GetWorkspaceSizeInBytes() const { return workspace_size_in_bytes_; } + + cudaStream_t GetCUDAStream() const { return stream_; } + + void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } + + CutlassSegmentGEMMHandler() {} + + ~CutlassSegmentGEMMHandler() {} + + private: + void* buffer_; + size_t workspace_size_in_bytes_; + cudaStream_t stream_; +}; + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GROUP_GEMM_HANDLER_CUH_ diff --git a/include/flashinfer/group_gemm/wrapper.cuh b/include/flashinfer/group_gemm/wrapper.cuh new file mode 100644 index 000000000..adc1d077b --- /dev/null +++ b/include/flashinfer/group_gemm/wrapper.cuh @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ +#define FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ + +#include + +#include "../allocator.h" +#include "handler.cuh" + +namespace flashinfer { + +namespace group_gemm { + +#define DISPATCH_WEIGHT_LAYOUT(is_column_major, WEIGHT_LAYOUT, ...) 
\ + if (is_column_major) { \ + using WEIGHT_LAYOUT = cutlass::layout::ColumnMajor; \ + __VA_ARGS__ \ + } else { \ + using WEIGHT_LAYOUT = cutlass::layout::RowMajor; \ + __VA_ARGS__ \ + } + +template +cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType* x, DType* w, + DType* y, int64_t* xy_indptr_d, int64_t* w_indices_d, + unsigned int batch_size, unsigned int d_in, + unsigned int d_out, bool weight_column_major, + cudaStream_t stream) { + AlignedAllocator allocator(handler->GetWorkspace(), handler->GetWorkspaceSizeInBytes()); + cutlass::gemm::GemmCoord* problem_sizes_device = + allocator.aligned_alloc( + batch_size * sizeof(cutlass::gemm::GemmCoord), 16); + DType** x_data = allocator.aligned_alloc(batch_size * sizeof(DType*), 16); + DType** w_data = allocator.aligned_alloc(batch_size * sizeof(DType*), 16); + DType** y_data = allocator.aligned_alloc(batch_size * sizeof(DType*), 16); + int64_t* ld_x = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16); + int64_t* ld_w = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16); + int64_t* ld_y = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16); + + // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API, + // so I just use the kernel function directly, need to investigate more. + auto compute_args_kernel = compute_cutlass_group_gemm_args; + compute_args_kernel<<>>( + problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w, + (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "Failed to launch kernel: " << cudaGetErrorString(err) << std::endl; + return err; + } + + using cutlass::epilogue::thread::LinearCombination; + using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; + DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + DType, // Element A + cutlass::layout::RowMajor, // Layout A + cutlass::ComplexTransform::kNone, // + 8, // Granularity A + DType, // Element B + WEIGHT_LAYOUT, // Layout B + cutlass::ComplexTransform::kNone, // + 8, // Granularity B + DType, // Element C&D + cutlass::layout::RowMajor, // Layout C&D + float, // Element Accumulator + cutlass::arch::OpClassTensorOp, // Operator Class Tag + cutlass::arch::Sm80, // Architecture + cutlass::gemm::GemmShape<128, 128, 32>, // Thread Block Shape + cutlass::gemm::GemmShape<64, 64, 32>, // Warp Shape + cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape + cutlass::epilogue::thread::LinearCombination, // Epilogue + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, // Swizzling Operator + 8 // Stages + >::GemmKernel; + + using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + typename GemmGrouped::Arguments args(problem_sizes_device, batch_size, 4, epilogue_op, x_data, + w_data, y_data, y_data, ld_x, ld_w, ld_y, ld_y); + + GemmGrouped gemm; + auto status = gemm.initialize(args, nullptr, stream); + if (status != cutlass::Status::kSuccess) { + std::ostringstream err_msg; + err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status); + throw std::runtime_error(err_msg.str()); + } + status = gemm.run(stream); + if (status != cutlass::Status::kSuccess) { + std::ostringstream err_msg; + err_msg << "cutlass group_gemm.run 
failed: " << cutlassGetStatusString(status); + throw std::runtime_error(err_msg.str()); + } + }); + + return cudaSuccess; +} + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 358674a6d..2c977fec4 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -308,6 +308,17 @@ std::tuple, std::vector> split_qo_in return {num_frags_x, num_qo_tiles, std::move(request_indices), std::move(tile_indices)}; } +template +inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { + std::vector host_array(size); + std::cout << prefix; + cudaMemcpy(host_array.data(), device_ptr, size * sizeof(T), cudaMemcpyDeviceToHost); + for (size_t i = 0; i < size; ++i) { + std::cout << host_array[i] << " "; + } + std::cout << std::endl; +} + } // namespace flashinfer #endif // FLASHINFER_UTILS_CUH_ diff --git a/python/3rdparty b/python/3rdparty new file mode 120000 index 000000000..303a6484e --- /dev/null +++ b/python/3rdparty @@ -0,0 +1 @@ +../3rdparty \ No newline at end of file diff --git a/python/MANIFEST.in b/python/MANIFEST.in index 070a8eb79..854badc80 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -10,6 +10,7 @@ include generate_single_prefill_inst.py include literal_map.py recursive-include include * recursive-include csrc * +recursive-include 3rdparty/cutlass * # wheel-only exclude flashinfer/_build_meta.py diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index ec05cf782..13ab21dfe 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -370,8 +370,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(k); diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index b088d07e2..d784665d6 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -72,4 +72,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); + py::class_(m, "CutlassSegmentGEMMPyTorchWrapper") + .def(py::init()) + .def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer) + .def("forward", &CutlassSegmentGEMMPyTorchWrapper::Forward); } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index a42cc1282..b16b6a570 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -164,3 +165,19 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { std::shared_ptr handler_; flashinfer::QKVLayout kv_layout_; }; + +class CutlassSegmentGEMMPyTorchWrapper { + public: + void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer); + + torch::Tensor 
Forward(torch::Tensor seg_indptr, torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, bool weight_column_major); + + CutlassSegmentGEMMPyTorchWrapper(torch::Tensor workspace_buffer) + : handler_(std::make_shared<flashinfer::group_gemm::CutlassSegmentGEMMHandler>()) { + RegisterWorkspaceBuffer(workspace_buffer); + } + + private: + std::shared_ptr<flashinfer::group_gemm::CutlassSegmentGEMMHandler> handler_; +}; diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu new file mode 100644 index 000000000..f8ee43887 --- /dev/null +++ b/python/csrc/group_gemm.cu @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <flashinfer/group_gemm/wrapper.cuh> + +#include "flashinfer_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer::group_gemm; + +void CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer(torch::Tensor workspace_buffer) { + handler_->RegisterWorkspace(static_cast<void*>(workspace_buffer.data_ptr()), + workspace_buffer.size(0) * workspace_buffer.element_size()); +} + +torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr, + torch::Tensor weight_indices, + torch::Tensor x, torch::Tensor weight, + unsigned int batch_size, + bool weight_column_major) { + // TODO(Zihao): Add more checks here + CHECK_CUDA(seg_indptr); + CHECK_CUDA(x); + CHECK_CUDA(weight); + CHECK_DIM(2, x); // x: [sum(m_i), d_in] + CHECK_DIM(3, weight); // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights, + // d_in, d_out] otherwise + int64_t cumulative_batch_size = x.size(0); + int64_t d_out = weight_column_major ? weight.size(1) : weight.size(2); + int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1); + CHECK_EQ(x.size(1), d_in); + auto y = torch::zeros({cumulative_batch_size, d_out}, x.options()); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + seg_indptr = seg_indptr.to(torch::kInt64); + + bool weight_indices_defined = weight_indices.numel() > 0; + if (weight_indices_defined) { + CHECK_CUDA(weight_indices); + weight_indices = weight_indices.to(torch::kInt64); + } + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + using cutlass_t = typename cutlass_dtype<c_type>::type; + auto status = CutlassSegmentGEMMWrapper<cutlass_t>( + handler_.get(), static_cast<cutlass_t*>(x.data_ptr()), + static_cast<cutlass_t*>(weight.data_ptr()), static_cast<cutlass_t*>(y.data_ptr()), + static_cast<int64_t*>(seg_indptr.data_ptr()), + weight_indices_defined ?
static_cast(weight_indices.data_ptr()) : nullptr, + batch_size, d_in, d_out, weight_column_major, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); + return true; + }); + + return y; +} \ No newline at end of file diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 64ce3ce27..9dbf9dbe3 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -45,6 +45,7 @@ chain_speculative_sampling, ) from .norm import rmsnorm +from .group_gemm import SegmentGEMMWrapper try: from ._build_meta import __version__ as __version__ diff --git a/python/flashinfer/group_gemm.py b/python/flashinfer/group_gemm.py new file mode 100644 index 000000000..7bc8e58fa --- /dev/null +++ b/python/flashinfer/group_gemm.py @@ -0,0 +1,145 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional +from .utils import get_indptr + +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + + +class SegmentGEMMWrapper: + r"""Wrapper for segment GEMM kernels.""" + + def __init__(self, workspace_buffer: torch.Tensor): + r"""Initialize the wrapper. + + Parameters + ---------- + workspace_buffer : torch.Tensor + The workspace buffer for the kernels, we use it to store the metadata for the segment GEMM whose + size is proportional to the number of segments (batch size), 1MB workspace is enough for most cases. + """ + self._workspace_buffer = workspace_buffer + self._wrapper = _kernels.CutlassSegmentGEMMPyTorchWrapper( + self._workspace_buffer + ) + + def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer, the device of the new workspace buffer should + be the same as the device of the input tensors. + """ + self._workspace_buffer = new_workspace_buffer + self._wrapper.register_workspace_buffer(new_workspace_buffer) + + def forward( + self, + x: torch.Tensor, + weights: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_lens: Optional[torch.Tensor] = None, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""Forward pass of segment GEMM. + + Compute the matrix multiplication between a batch of input tensor (with variable number of rows, but fixed + number of columns) and a batch of weight tensor with fixed number of rows and columns: + + .. math:: + + y[i] = x[i] \times W[i] + + if :attr:`weight_indices` is provided, we will select the weight tensor based on the indices in the + :attr:`weight_indices` tensor: + + .. 
math:: + + y[i] = x[i] \times W[weight_indices[i]] + + We use Ragged Tensor to represent the input tensor :attr:`x` and the output tensor :attr:`y`, and each x[i] + is a segment of the concatenated tensor. Please see :ref:`Ragged Tensor tutorial ` for more details. + We use a ``seg_len`` or ``seg_indptr`` tensor (either would work) to indicate the start and end of each segment, + where the ``seg_indptr`` is the cumulative sum of the ``seg_lens`` tensor (with an additional 0 at the beginning): + + .. math:: + + \text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0 + + - If ``seg_lens`` is provided, then :attr:`x` has shape ``(sum(seg_lens), d_in)`` and :attr:`y` has shape + ``(sum(seg_lens), d_out)``, where ``d_in`` is the number of columns of the input tensor and ``d_out`` is the + number of columns of the output tensor. + - If ``seg_indptr`` is provided, then :attr:`x` has shape ``(seg_indptr[-1], d_in)`` and :attr:`y` has shape + ``(seg_indptr[-1], d_out)``. + + Parameters + ---------- + x : torch.Tensor + The input tensor with shape ``(sum(seg_lens), d_in)``. + weights : torch.Tensor + The 3D weight tensor with shape ``(num_weights, d_in, d_out)`` if :attr:`weight_column_major` is ``False``, + or ``(num_weights, d_out, d_in)`` if :attr:`weight_column_major` is ``True``. + batch_size : int + The number of segments. + weight_column_major : bool + Whether the weight tensor is column major. + seg_lens : Optional[torch.Tensor] + The length of each segment, with shape ``(batch_size,)``, expects a 1D tensor of dtype ``torch.int64``. + seg_indptr : Optional[torch.Tensor] + The indptr of the segments, with shape ``(batch_size + 1,)``, expects a 1D tensor of dtype ``torch.int64``. + If this is provided, then :attr:`seg_lens` will be ignored, otherwise ``seg_indptr`` will be computed + internally from :attr:`seg_lens`. + weight_indices : Optional[torch.Tensor] + The indices of the weight tensor to be selected for each segment, with shape ``(batch_size,)``. + Expects a 1D tensor of dtype ``torch.int64``. + If this is provided, then the weight tensor will be selected based on the indices in this tensor. + + Returns + ------- + torch.Tensor + The output tensor with shape ``(sum(seg_lens), d_out)``. 
+ """ + if seg_lens is None and seg_indptr is None: + raise ValueError("Either seg_lens or seg_indptr should be provided.") + if seg_indptr is None: + seg_indptr = get_indptr(seg_lens.to(x)) + if weight_indices is None: + # create an empty CPU tensor as placeholder + weight_indices = torch.empty(0, dtype=torch.int64) + return self._wrapper.forward( + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 664ac879d..c17536524 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -57,3 +57,10 @@ def check_kv_layout(kv_layout: str): def is_float8(x: torch.Tensor): return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + + +def get_indptr(x: torch.Tensor): + x = x.to(torch.int64) + ret = torch.zeros(x.shape[0] + 1, dtype=x.dtype, device=x.device) + ret[1:] = x.cumsum(0) + return ret diff --git a/python/setup.py b/python/setup.py index ddc035441..fff016b75 100644 --- a/python/setup.py +++ b/python/setup.py @@ -386,10 +386,14 @@ def __init__(self, *args, **kwargs) -> None: "csrc/batch_prefill.cu", "csrc/sampling.cu", "csrc/norm.cu", + "csrc/group_gemm.cu", ] + get_instantiation_cu(), include_dirs=[ str(root.resolve() / "include"), + str( + root.resolve() / "3rdparty" / "cutlass" / "include" + ), # for group gemm ], extra_compile_args={ "cxx": [ diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index a72704dc2..57f280b01 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -43,7 +43,7 @@ def test_batch_prefill_with_paged_kv_cache( causal, kv_layout, pos_encoding_mode, - enable_cuda_graph + enable_cuda_graph, ): q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len @@ -81,23 +81,29 @@ def test_batch_prefill_with_paged_kv_cache( head_dim, page_size, ) - o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) + o = wrapper.forward( + q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode + ) else: q_indptr_buffer = torch.empty(batch_size + 1).int().to(0) kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0) kv_indices_buffer = torch.empty(total_num_pages).int().to(0) kv_last_page_len_buffer = torch.empty(batch_size).int().to(0) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, enable_cuda_graph=True, + workspace_buffer, + kv_layout, + enable_cuda_graph=True, qo_indptr_buf=q_indptr_buffer, paged_kv_indptr_buf=kv_indptr_buffer, paged_kv_indices_buf=kv_indices_buffer, - paged_kv_last_page_len_buf=kv_last_page_len_buffer + paged_kv_last_page_len_buf=kv_last_page_len_buffer, ) q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len kv_indptr_warmup = torch.arange(0, batch_size + 1).int() kv_indices_warmup = torch.arange(0, batch_size).int() - kv_last_page_len_warmup = torch.full((batch_size,), page_size, dtype=torch.int32) + kv_last_page_len_warmup = torch.full( + (batch_size,), page_size, dtype=torch.int32 + ) wrapper.begin_forward( q_indptr_warmup, kv_indptr_warmup, @@ -113,9 +119,7 @@ def test_batch_prefill_with_paged_kv_cache( s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): - o = wrapper.forward( - q, kv_data, pos_encoding_mode=pos_encoding_mode - ) + o = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) 
torch.cuda.current_stream().wait_stream(s) # capture g = torch.cuda.CUDAGraph() @@ -148,7 +152,9 @@ def test_batch_prefill_with_paged_kv_cache( ( kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]] if kv_layout == "HND" - else kv_data[kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :] + else kv_data[ + kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], : + ] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -163,7 +169,9 @@ def test_batch_prefill_with_paged_kv_cache( ( kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]] if kv_layout == "HND" - else kv_data[kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :] + else kv_data[ + kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], : + ] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -381,4 +389,3 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( ) test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE") test_batch_prefill_with_ragged_kv_cache_custom_mask(12, 137, 137, 8, 8, 128, "NONE") - diff --git a/python/tests/test_group_gemm.py b/python/tests/test_group_gemm.py new file mode 100644 index 000000000..3fb4fb8eb --- /dev/null +++ b/python/tests/test_group_gemm.py @@ -0,0 +1,107 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import flashinfer +import numpy as np +import torch +import pytest + + +@pytest.mark.parametrize("batch_size", [1, 77, 199]) +@pytest.mark.parametrize("num_rows_per_batch", [3, 10, 99]) +@pytest.mark.parametrize("d_in", [128, 1024, 4096]) +@pytest.mark.parametrize("d_out", [128, 1024, 4096]) +@pytest.mark.parametrize("use_weight_indices", [False, True]) +@pytest.mark.parametrize("column_major", [False, True]) +def test_segment_gemm( + batch_size, + num_rows_per_batch, + d_in, + d_out, + use_weight_indices, + column_major, +): + torch.manual_seed(42) + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + segment_gemm = flashinfer.group_gemm.SegmentGEMMWrapper(workspace_buffer) + x = ( + (torch.randn(batch_size * num_rows_per_batch, d_in) / 10) + .to(0) + .to(torch.float16) + ) + if use_weight_indices: + num_weights = 1024 + if column_major: + weight = ( + (torch.randn(num_weights, d_out, d_in) / 10).to(0).to(torch.float16) + ) + else: + weight = ( + (torch.randn(num_weights, d_in, d_out) / 10).to(0).to(torch.float16) + ) + else: + weight = (torch.randn(batch_size, d_in, d_out) / 10).to(0).to(torch.float16) + y = segment_gemm.forward( + x, + weight, + batch_size, + weight_column_major=column_major, + seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64), + weight_indices=( + (torch.arange(0, batch_size) % num_weights).to(0) + if use_weight_indices + else None + ), + ) + + if use_weight_indices: + for i in range(batch_size): + np.testing.assert_allclose( + y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch].cpu().numpy(), + torch.matmul( + x[i * num_rows_per_batch : (i + 1) * num_rows_per_batch], + ( + weight[i % num_weights].T + if column_major + else weight[i % num_weights] + ), + ) + .cpu() + .numpy(), + rtol=1e-3, + atol=1e-3, + err_msg="assertion failed at batch {}".format(i), + ) + else: + np.testing.assert_allclose( + y.cpu().numpy(), + torch.matmul( + x.view(batch_size, num_rows_per_batch, d_in), + weight.transpose(-1, -2) if column_major else weight, + ) + .view(batch_size * num_rows_per_batch, d_out) + .cpu() + .numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + test_segment_gemm(199, 99, 128, 128, False, False) + test_segment_gemm(199, 99, 128, 128, False, True) + test_segment_gemm(199, 99, 128, 128, True, False) + test_segment_gemm(199, 99, 128, 128, True, True)