diff --git a/CMakeLists.txt b/CMakeLists.txt
index ad6736c47f459..aa15b632cdd3b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -179,9 +179,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/custom_all_reduce.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
+    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
 
 #
 # The CUTLASS kernels for Hopper require sm90a to be enabled.
@@ -189,7 +189,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
   set_source_files_properties(
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
     PROPERTIES
     COMPILE_FLAGS
     "-gencode arch=compute_90a,code=sm_90a")
diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index 6de56f618700d..182105f0b33f2 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
 def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                  scale_b: torch.tensor,
                  out_dtype: torch.dtype) -> torch.tensor:
-    return ops.cutlass_scaled_mm_dq(a,
-                                    b,
-                                    scale_a,
-                                    scale_b,
-                                    out_dtype=out_dtype)
+    return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
 
 
 # bench
diff --git a/csrc/ops.h b/csrc/ops.h
index 0c270a78c331f..9e2e977fa3c2e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -90,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
-void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
-                          torch::Tensor const& b, torch::Tensor const& a_scales,
-                          torch::Tensor const& b_scales);
+void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
+                       torch::Tensor const& b, torch::Tensor const& a_scales,
+                       torch::Tensor const& b_scales);
 
 #endif
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
similarity index 71%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
rename to csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index 23a8b4070b70e..7651268dc5316 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -29,21 +29,14 @@ using namespace cute;
 
 /*
-   This defines a quantized GEMM operation with dequantized output, similar to
-   torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
+   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
    NVIDIA GPUs with SM versions prior to sm90 (Hopper).
 
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
+   Epilogue functions can be defined to post-process the output before it is
+   written to GPU memory.
+   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+   as well as a static prepare_args function that constructs an
+   EVTCompute::Arguments struct.
 */
 
 namespace {
 
@@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
   }
 };
 
-template typename ArchGuard,
-          typename ElementAB_, typename ElementD_, typename TileShape,
-          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
-struct cutlass_2x_gemm {
-  using ElementAB = ElementAB_;
-  using ElementD = ElementD_;
-
-  using ElementAcc =
-      typename std::conditional, int32_t,
-                                 float>::type;
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.
 
-  using Operator =
-      typename std::conditional,
-                                 cutlass::arch::OpMultiplyAddSaturate,
-                                 cutlass::arch::OpMultiplyAdd>::type;
+   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+   per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or column is supported.
+   A and B must have symmetric quantization (zero point == 0).
 
-  using OutputTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
-          >;
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.
+
+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor or
+   per row or column.
+*/
+template
+struct ScaledEpilogue {
+ private:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
 
   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
       cutlass::multiplies, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using EVTCompute1 =
+ public:
+  using EVTCompute =
       cutlass::epilogue::threadblock::Sm80EVT;
 
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
+    return evt_compute_args;
+  }
+};
+
+template typename ArchGuard,
+          typename ElementAB_, typename ElementD_,
+          template typename Epilogue_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
+struct cutlass_2x_gemm {
+  using ElementAB = ElementAB_;
+  using ElementD = ElementD_;
+
+  using ElementAcc =
+      typename std::conditional, int32_t,
+                                 float>::type;
+
+  using Operator =
+      typename std::conditional,
+                                 cutlass::arch::OpMultiplyAddSaturate,
+                                 cutlass::arch::OpMultiplyAdd>::type;
+
+  using OutputTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
+          >;
+
+  using Epilogue = Epilogue_;
+  using EVTCompute = typename Epilogue::EVTCompute;
 
   using D = cutlass::epilogue::threadblock::VisitorAuxStore<
       OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
       Stride, Int<0>>>;
 
-  using EVTD = cutlass::epilogue::threadblock::Sm80EVT;
+  using EVTD = cutlass::epilogue::threadblock::Sm80EVT;
 
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
@@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
   using Op = cutlass::gemm::device::GemmUniversalAdapter;
 };
 
-template
-void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     torch::Tensor const& a_scales,
-                                     torch::Tensor const& b_scales) {
+template
+void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                         torch::Tensor const& b,
+                         EpilogueArgs&&... epilogue_params) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;
 
@@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   auto b_ptr = static_cast(b.data_ptr());
   auto c_ptr = static_cast(out.data_ptr());
 
-  auto a_scales_ptr = a_scales.data_ptr();
-  auto b_scales_ptr = b_scales.data_ptr();
-
-  using ScaleAArgs = typename Gemm::ScaleA::Arguments;
-  using ScaleBArgs = typename Gemm::ScaleB::Arguments;
-
-  ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}};
-  ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}};
-
-  typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
-
-  typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
-                                                          evt0_compute_args};
   typename Gemm::D::Arguments d_args{c_ptr, c_stride};
 
+  using Epilogue = typename Gemm::Epilogue;
+  auto evt_args =
+      Epilogue::prepare_args(std::forward(epilogue_params)...);
+
   typename Gemm::EVTD::Arguments epilogue_args{
-      evt1_compute_args,
+      evt_args,
       d_args,
   };
 
@@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
 
 }  // namespace
 
-void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales,
-                                              b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, a_scales, b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales,
-                                              b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, a_scales, b_scales);
   }
 }
 
-void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales,
-                                              b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, a_scales, b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales,
-                                              b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, a_scales, b_scales);
   }
 }
 
-void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales,
-                                                b_scales);
+          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales,
-                                                b_scales);
+          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher>(
-          out, a, b, a_scales, b_scales);
+          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
+          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher>(
-          out, a, b, a_scales, b_scales);
+          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
+          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
     }
   }
 }
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
similarity index 66%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
rename to csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index a99802153643a..f1a2b73ff962b 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -32,21 +32,14 @@ using namespace cute;
 
 /*
-   This defines a quantized GEMM operation with dequantized output, similar to
-   torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
+   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm90a (Hopper) or later.
 
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
+   Epilogue functions can be defined to post-process the output before it is
+   written to GPU memory.
+   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
+   as well as a static prepare_args function that constructs an
+   EVTCompute::Arguments struct.
 */
 
 namespace {
 
@@ -71,21 +64,25 @@ struct enable_sm90_or_later : Kernel {
   }
 };
 
-template
-struct cutlass_3x_gemm {
-  using ElementAB = ElementAB_;
-  using ElementD = ElementD_;
-  using ElementAcc =
-      typename std::conditional, int32_t,
-                                 float>::type;
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.
-  using EpilogueDescriptor =
-      cutlass::epilogue::collective::detail::EpilogueDescriptor<
-          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
-          ElementD, EpilogueSchedule>;
+   A and B may be both either int8 or fp8_e4m3. A can be
+   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or column is supported.
+   A and B must have symmetric quantization (zero point == 0).
 
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.
+
+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor or
+   per row or column.
+*/
+template
+struct ScaledEpilogue {
+ private:
   using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
 
   using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
@@ -111,19 +108,53 @@ struct cutlass_3x_gemm {
       cutlass::multiplies, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using EVTCompute1 =
+ public:
+  using EVTCompute =
       cutlass::epilogue::fusion::Sm90EVT;
 
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+    using ScaleA_Args = typename ScaleA::Arguments;
+    using ScaleB_Args = typename ScaleB::Arguments;
+
+    ScaleA_Args a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}};
+    ScaleB_Args b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}};
+
+    return ArgumentType{a_args, {b_args}};
+  }
+};
+
+template typename Epilogue_,
+          typename TileShape, typename ClusterShape, typename KernelSchedule,
+          typename EpilogueSchedule>
+struct cutlass_3x_gemm {
+  using ElementAB = ElementAB_;
+  using ElementD = ElementD_;
+  using ElementAcc =
+      typename std::conditional, int32_t,
+                                 float>::type;
+
+  using EpilogueDescriptor =
+      cutlass::epilogue::collective::detail::EpilogueDescriptor<
+          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
+          ElementD, EpilogueSchedule>;
+
+  using Epilogue = Epilogue_;
 
   using StrideD = Stride, Int<0>>;
   using ElementC = void;
   using StrideC = StrideD;
 
+  using EVTCompute = typename Epilogue::EVTCompute;
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
           ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-          EpilogueSchedule, EVTCompute1>::CollectiveOp;
+          EpilogueSchedule, EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -148,11 +179,10 @@ struct cutlass_3x_gemm {
   struct GemmKernel : public KernelType {};
 };
 
-template
-void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     torch::Tensor const& a_scales,
-                                     torch::Tensor const& b_scales) {
+template
+void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                         torch::Tensor const& b,
+                         EpilogueArgs&&... epilogue_params) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;
 
@@ -182,19 +212,13 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   auto c_ptr = static_cast(out.data_ptr());
 
   typename GemmKernel::EpilogueArguments epilogue_args{
-      {}, c_ptr, c_stride, c_ptr, c_stride};
+      Gemm::Epilogue::prepare_args(
+          std::forward(epilogue_params)...),
+      c_ptr, c_stride, c_ptr, c_stride};
 
   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                       prob_shape, mainloop_args, epilogue_args};
 
-  using ScaleA_Args = typename Gemm::ScaleA::Arguments;
-  using ScaleB_Args = typename Gemm::ScaleB::Arguments;
-
-  ScaleA_Args a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}};
-  ScaleB_Args b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}};
-
-  args.epilogue.thread = {a_args, {b_args}};
-
   // Launch the CUTLASS GEMM kernel.
   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter;
   GemmOp gemm_op;
@@ -209,7 +233,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   CUTLASS_CHECK(status);
 }
 
-template
+template typename Epilogue, int32_t M>
 struct sm90_fp8_config {
   static_assert(std::is_same());
   using KernelSchedule =
@@ -219,12 +244,13 @@ struct sm90_fp8_config {
   using ClusterShape = Shape<_2, _1, _1>;
 
   using Cutlass3xGemm =
-      cutlass_3x_gemm;
+      cutlass_3x_gemm;
 };
 
-template
-struct sm90_fp8_config {
+template typename Epilogue>
+struct sm90_fp8_config {
   static_assert(std::is_same());
   using KernelSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@@ -233,12 +259,13 @@ struct sm90_fp8_config {
   using ClusterShape = Shape<_2, _1, _1>;
 
   using Cutlass3xGemm =
-      cutlass_3x_gemm;
+      cutlass_3x_gemm;
 };
 
-template
-struct sm90_fp8_config {
+template typename Epilogue>
+struct sm90_fp8_config {
   static_assert(std::is_same());
   using KernelSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@@ -247,30 +274,28 @@ struct sm90_fp8_config {
   using ClusterShape = Shape<_1, _8, _1>;
 
   using Cutlass3xGemm =
-      cutlass_3x_gemm;
+      cutlass_3x_gemm;
 };
 
 }  // namespace
 
-template
-void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
-                                            torch::Tensor const& a,
-                                            torch::Tensor const& b,
-                                            torch::Tensor const& a_scales,
-                                            torch::Tensor const& b_scales) {
+template typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
+                                    torch::Tensor const& b,
+                                    EpilogueArgs&&... args) {
   static_assert(std::is_same());
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   using Cutlass3xGemmDefault =
-      typename sm90_fp8_config::Cutlass3xGemm;
+      typename sm90_fp8_config::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
-      typename sm90_fp8_config::Cutlass3xGemm;
+      typename sm90_fp8_config::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
-      typename sm90_fp8_config::Cutlass3xGemm;
+      typename sm90_fp8_config::Cutlass3xGemm;
 
   uint32_t const m = a.size(0);
   uint32_t const mp2 =
@@ -278,23 +303,23 @@ void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
 
   if (mp2 <= 64) {
     // m in [1, 64]
-    return cutlass_scaled_mm_dq_dispatcher(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_caller(
+        out, a, b, std::forward(args)...);
   } else if (mp2 <= 128) {
     // m in (64, 128]
-    return cutlass_scaled_mm_dq_dispatcher(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_caller(
+        out, a, b, std::forward(args)...);
   } else {
     // m in (128, inf)
-    return cutlass_scaled_mm_dq_dispatcher(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_caller(
+        out, a, b, std::forward(args)...);
   }
 }
 
-void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
@@ -308,16 +333,15 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
     using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_3x_gemm>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller>(out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_3x_gemm>(
+      return cutlass_gemm_caller<
+          cutlass_3x_gemm>(
           out, a, b, a_scales, b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_sm90_fp8_dispatch(
+      return cutlass_gemm_sm90_fp8_dispatch<
+          cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
           out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_sm90_fp8_dispatch(
+      return cutlass_gemm_sm90_fp8_dispatch(
           out, a, b, a_scales, b_scales);
     }
   }
 }
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
similarity index 50%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
rename to csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 423e64a4932e2..687f8efd8dc00 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -3,31 +3,31 @@
 #include
 #include
 
-void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales);
+void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 
-void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales);
+void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 
-void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales);
+void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
-void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales);
+void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales);
 #endif
 
-void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
-                          torch::Tensor const& b, torch::Tensor const& a_scales,
-                          torch::Tensor const& b_scales) {
+void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                       torch::Tensor const& b, torch::Tensor const& a_scales,
+                       torch::Tensor const& b_scales) {
   int32_t major_capability;
   int32_t minor_capability;
   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@@ -57,19 +57,19 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
 
     // Guard against compilation issues for sm90 kernels
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
-    cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
+    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
 #else
-    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
+    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
 #endif
   } else if (version_num == 89) {
     // Ada Lovelace
-    cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
+    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
   } else if (version_num >= 80) {
     // Ampere
-    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
+    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
   } else {
     // Turing
     TORCH_CHECK(version_num >= 75);
-    cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
+    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
   }
 }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index df2603544c85a..867bf438937cd 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -136,10 +136,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization.
   ops.def(
-      "cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
-      "                     Tensor b, Tensor a_scales,"
-      "                     Tensor b_scales) -> ()");
-  ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
+      "cutlass_scaled_mm(Tensor! out, Tensor a,"
+      "                  Tensor b, Tensor a_scales,"
+      "                  Tensor b_scales) -> ()");
+  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 #endif
 
   // Quantized GEMM for GPTQ.
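The schema registered above keeps the destination-passing convention of the old cutlass_scaled_mm_dq op: the kernel writes into a caller-allocated out tensor (Tensor! out), and the wrapper in vllm/_custom_ops.py below allocates that buffer. A minimal sketch of driving the raw binding directly, with shapes, scale values, and the importing module assumed for illustration rather than taken from this patch:

    import torch
    from vllm import _custom_ops  # assumed entry point that loads the _C extension

    m, n, k = 32, 64, 128
    a = torch.randint(-8, 8, (m, k), dtype=torch.int8, device="cuda")
    # b must be column-major; creating (n, k) and transposing gives a (k, n) view.
    b = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda").t()
    scale_a = torch.full((1, 1), 0.05, dtype=torch.float32, device="cuda")  # per-tensor
    scale_b = torch.full((1, 1), 0.02, dtype=torch.float32, device="cuda")  # per-tensor

    # The op returns nothing; it writes the scaled GEMM result into out.
    out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda")
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)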
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index a9aeeb3a78bf5..e7368fb87b6ae 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -52,7 +52,7 @@ def cutlass_fp8_gemm_helper(m: int,
     scale_b = (torch.randn(
         (1, n_b_scales), device=device, dtype=torch.float32) / 10)
 
-    out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b * b.to(dtype=torch.float32)).to(out_dtype)
 
@@ -79,7 +79,7 @@ def cutlass_int8_gemm_helper(m: int,
     scale_b = (torch.randn(
         (1, n_b_scales), device=device, dtype=torch.float32) / 10)
 
-    out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b * b.to(dtype=torch.float32)).to(dtype=out_dtype)
 
@@ -205,11 +205,11 @@ def test_cutlass_subset():
     scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
     scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
 
-    out = ops.cutlass_scaled_mm_dq(a,
-                                   b,
-                                   scale_a,
-                                   scale_b,
-                                   out_dtype=torch.bfloat16)
+    out = ops.cutlass_scaled_mm(a,
+                                b,
+                                scale_a,
+                                scale_b,
+                                out_dtype=torch.bfloat16)
     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b * b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
 
@@ -228,8 +228,8 @@ def __init__(self, b, scale_a, scale_b, out_dtype):
         self.out_dtype = out_dtype
 
     def forward(self, a):
-        return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
-                                        self.out_dtype)
+        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
+                                     self.out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 955086be132e5..2f84b8bde6b57 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -212,9 +212,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 # cutlass
-def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
-                         scale_a: torch.Tensor, scale_b: torch.Tensor,
-                         out_dtype: Type[torch.dtype]) -> torch.Tensor:
+def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                      scale_b: torch.Tensor,
+                      out_dtype: Type[torch.dtype]) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
 
@@ -222,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
     n = b.shape[1]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
 
-    torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
-
+    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)
     return out
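The wrapper keeps the old signature apart from the name, so call sites only swap cutlass_scaled_mm_dq for cutlass_scaled_mm. A short sketch of the weight-quantized usage pattern that the compressed-tensors schemes below follow; layer shapes and scale values here are invented for illustration:

    import torch
    from vllm import _custom_ops as custom_ops

    x = torch.randn((16, 256), dtype=torch.float16, device="cuda")               # activations
    weight = torch.randint(-8, 8, (512, 256), dtype=torch.int8, device="cuda")   # [out_features, in_features]
    weight_scale = torch.full((1, 1), 0.01, dtype=torch.float32, device="cuda")  # per-tensor; a (1, n) per-column scale also works

    # Dynamic per-token activation quantization, then the fused scaled GEMM.
    x_q, input_scales = custom_ops.scaled_int8_quant(x)
    y = custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales, weight_scale, x.dtype)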
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
index 25b707caeef33..9bb7bf4470872 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
@@ -81,5 +81,5 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight_scale = layer.weight_scale
 
         x_q, input_scales = custom_ops.scaled_int8_quant(x)
-        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
-                                               weight_scale, x.dtype)
+        return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
+                                            weight_scale, x.dtype)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
index 7559fc0f95b24..88c15c5c26a11 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
@@ -99,5 +99,5 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         # Input quantize
         x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
 
-        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
-                                               weight_scale, x.dtype)
+        return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
+                                            weight_scale, x.dtype)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 0cf2bd927a800..e89fd65813c05 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -261,7 +261,7 @@ def apply(self,
             qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
-            output = ops.cutlass_scaled_mm_dq(
+            output = ops.cutlass_scaled_mm(
                 qinput,
                 layer.weight,
                 out_dtype=x.dtype,