From 96c98ecbfa1b9b1a286d150cada764bdcd4de64e Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Mon, 25 Jul 2022 10:20:45 +0200 Subject: [PATCH] MemoryEff attention forward: Properly fuse matmul and enable TensorCores on the second matmul (#368) * Generic backwards * Guard backward to sm75 only * bounds checking for gradV * clang-format * Fused gemm working for Sm80/Sm75 f16/f32 * WIP * Volta TensorOp for f16 * Working on A100 again * SIMT working * Code cleanup 1 * Code cleanup2 * BUGFIX for shared memory limit * Remove code * clang-format * Remove code again * Remove draft of backward * Enforce alignment for fp16 * Fix tests * Fix constraint on seq length when not using tensorcores * Fix alignment requirements for V100/tensorcores * Clang-format * Update xformers/components/attention/csrc/cuda/attention_forward_generic.cu Co-authored-by: Francisco Massa * Address comments from fmassa Co-authored-by: danthe3rd Co-authored-by: danthe3rd Co-authored-by: Francisco Massa --- setup.py | 5 +- tests/test_mem_eff_attention.py | 2 +- .../cutlass/include/cutlass/arch/mma.h | 9 +- .../cutlass/include/cutlass/arch/mma_sm50.h | 78 +- .../warp/mma_tensor_op_tile_iterator_sm70.h | 2 +- .../benchmarks/benchmark_mem_eff_attention.py | 4 + .../csrc/cuda/attention_forward_generic.cu | 581 +++---- .../cuda/attention_scaling_coefs_updater.h | 2 +- .../csrc/cuda/epilogue_rescale_output.h | 707 ++++++++ .../attention/csrc/cuda/find_default_mma.h | 245 +-- .../attention/csrc/cuda/mma_from_smem.h | 1509 +++++++++++++++++ xformers/ops.py | 52 +- 12 files changed, 2656 insertions(+), 540 deletions(-) create mode 100644 xformers/components/attention/csrc/cuda/epilogue_rescale_output.h create mode 100644 xformers/components/attention/csrc/cuda/mma_from_smem.h diff --git a/setup.py b/setup.py index 3cfd68060b..a084b6d2e2 100644 --- a/setup.py +++ b/setup.py @@ -172,7 +172,10 @@ def get_extensions(): if cuda_version >= 1102: nvcc_flags += ["--threads", "4"] extra_compile_args["nvcc"] = nvcc_flags - if cuda_version >= 1100 and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0": + if ( + cuda_version >= 1100 + and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0" + ): ext_modules += get_flash_attention_extensions( cuda_version=cuda_version, extra_compile_args=extra_compile_args ) diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 683b26a6a9..08097eaffc 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -112,7 +112,7 @@ def test_memory_efficient_attention( ).float() ref = ref_attention(query, key, value, attn_bias) - assert_allclose(out, ref, atol=op.FORWARD_ERROR_ATOL) + assert_allclose(out, ref, atol=op.FORWARD_ERROR_ATOL[dtype]) @pytest.mark.parametrize("k_len", [5, 6, 32]) diff --git a/third_party/cutlass/include/cutlass/arch/mma.h b/third_party/cutlass/include/cutlass/arch/mma.h index e79a4099ff..ce3e02f365 100644 --- a/third_party/cutlass/include/cutlass/arch/mma.h +++ b/third_party/cutlass/include/cutlass/arch/mma.h @@ -143,16 +143,17 @@ template < /// Layout of B matrix (concept: MatrixLayout) typename LayoutB, /// Element type of C matrix - typename ElementC, + typename ElementC_, /// Layout of C matrix (concept: MatrixLayout) typename LayoutC, /// Inner product operator typename Operator_ > -struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> { +struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> { using Shape = gemm::GemmShape<1, 1, 
1>; using Operator = Operator_; + using ElementC = ElementC_; CUTLASS_HOST_DEVICE void operator()( @@ -218,8 +219,8 @@ struct SparseMma; #include "cutlass/arch/mma_sm50.h" #include "cutlass/arch/mma_sm60.h" #include "cutlass/arch/mma_sm61.h" -#include "cutlass/arch/mma_sm70.h" -#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm70.h" +#include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" #include "cutlass/arch/mma_sparse_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/third_party/cutlass/include/cutlass/arch/mma_sm50.h b/third_party/cutlass/include/cutlass/arch/mma_sm50.h index 96977c4134..1b071a5281 100644 --- a/third_party/cutlass/include/cutlass/arch/mma_sm50.h +++ b/third_party/cutlass/include/cutlass/arch/mma_sm50.h @@ -62,6 +62,7 @@ struct Mma, 1, float, LayoutA, float, LayoutB, float, L using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAdd; + using ElementC = float; CUTLASS_HOST_DEVICE void operator()( @@ -144,12 +145,12 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; @@ -184,12 +185,12 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - float, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + float, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; @@ -222,12 +223,12 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - float, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + float, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; @@ -260,12 +261,12 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; @@ -298,12 +299,12 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - complex, - LayoutA, - double, - LayoutB, - complex, - LayoutC, + complex, + LayoutA, + double, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; @@ -334,12 +335,12 @@ template < struct Mma< gemm::GemmShape<1, 1, 1>, 1, - double, - LayoutA, - complex, - LayoutB, - complex, - LayoutC, + double, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; @@ -373,7 +374,8 @@ struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, using Shape = gemm::GemmShape<1, 1, 1>; using Operator = OpMultiplyAdd; - + using ElementC = float; + CUTLASS_HOST_DEVICE void operator()( Array &d, @@ -412,7 +414,7 @@ struct Mma, 1, Quaternion, LayoutA, Quaternion op; d[0] = op(a[0], b[0], c[0]); } - + }; } diff --git a/third_party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/third_party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index 5c824b26ea..f136459a6e 100644 --- a/third_party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/third_party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -1205,7 +1205,7 @@ class MmaVoltaTensorOpAccumulatorTileIterator { InterleavedTile::kColumn / InstructionShape::kN>; }; -private: +public: // Assume accumulator tile is multipile interleaved 32x32 tile. 
static int const kElementsPerPartial = 4; diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py index 7d706b2660..8a61a093b0 100644 --- a/xformers/benchmarks/benchmark_mem_eff_attention.py +++ b/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -15,6 +15,8 @@ import xformers.ops +torch.backends.cuda.matmul.allow_tf32 = False + def create_attn_bias( bias_type, batch_size: int, q_len: int, kv_len: int, device, dtype @@ -89,6 +91,7 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype): k=K, attn_bias_type=attn_bias_type, has_dropout=False, + kv_len=M, ) try: op = dispatch.op if FORCE_OP is None else FORCE_OP @@ -162,6 +165,7 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype): k=K, attn_bias_type=attn_bias_type, has_dropout=False, + kv_len=M, ) try: op = dispatch.op if FORCE_OP is None else FORCE_OP diff --git a/xformers/components/attention/csrc/cuda/attention_forward_generic.cu b/xformers/components/attention/csrc/cuda/attention_forward_generic.cu index 6b84c2a4d7..1888d314f5 100644 --- a/xformers/components/attention/csrc/cuda/attention_forward_generic.cu +++ b/xformers/components/attention/csrc/cuda/attention_forward_generic.cu @@ -14,33 +14,50 @@ #include "cutlass/numeric_types.h" #include "attention_scaling_coefs_updater.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" #include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass/gemm/threadblock/default_mma_core_simt.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" #include "cutlass/matrix_shape.h" #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "epilogue_rescale_output.h" #include "find_default_mma.h" +#include "mma_from_smem.h" #include -// XXX: Maybe CUDA will wake up one day and provide this +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } + template -struct math; +struct TypeTraits; template <> -struct math { +struct TypeTraits { using scalar_t = cutlass::half_t; using torch_dtype = half; static constexpr at::ScalarType kAtScalarType = at::ScalarType::Half; - static __device__ __forceinline__ cutlass::half_t exp( - cutlass::half_t const& h) { - return cutlass::half_t(hexp(h.to_half())); - } template static __host__ at::PackedTensorAccessor32 packed_accessor( at::Tensor const& tensor) { @@ -50,24 +67,21 @@ struct math { tensor.strides().data()); } }; -constexpr at::ScalarType math::kAtScalarType; +constexpr at::ScalarType TypeTraits::kAtScalarType; template <> -struct math { +struct TypeTraits { using scalar_t = float; using torch_dtype = float; static constexpr at::ScalarType kAtScalarType = at::ScalarType::Float; - static __device__ __forceinline__ float exp(float const& 
h) { - return expf(h); - } template static __host__ at::PackedTensorAccessor32 packed_accessor( at::Tensor const& tensor) { return tensor.packed_accessor32(); } }; -constexpr at::ScalarType math::kAtScalarType; +constexpr at::ScalarType TypeTraits::kAtScalarType; namespace { template @@ -75,48 +89,69 @@ constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { return (n + m - 1) / m; } +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) +// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below template -struct GemmTypeQK { - // Default GEMM with simt +struct DefaultGemmType { static constexpr int ThreadK = 8; static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; using OpClass = cutlass::arch::OpClassSimt; using Operator = cutlass::arch::OpMultiplyAdd; }; -// Using GEMM with TensorCores when available +// Specialization for tensorcores with f32 template -struct GemmTypeQK< +struct DefaultGemmType< ArchTag, - float, // scalar_t_ + float, typename std::enable_if= 80>::type> { static constexpr int ThreadK = 32; static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; using OpClass = cutlass::arch::OpClassTensorOp; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; using Operator = cutlass::arch::OpMultiplyAddFastF32; }; +// Specialization for tensorcores with f16 - Sm75+ template -struct GemmTypeQK< +struct DefaultGemmType< ArchTag, - cutlass::half_t, // scalar_t_ - typename std::enable_if= 70>::type> { + cutlass::half_t, + typename std::enable_if= 75>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { static constexpr int ThreadK = 32; static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; using OpClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = typename std::conditional< - ArchTag::kMinComputeCapability >= 75, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::gemm::GemmShape<8, 8, 4>>::type; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; using Operator = cutlass::arch::OpMultiplyAdd; }; template < + // The datatype of Q/K/V typename scalar_t_, + // Intermediate accumulation type (including softmax) typename accum_t_, + // Output type (only float tested so far) typename output_t_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel bool isAligned_> struct AttentionKernelInfo { using scalar_t = scalar_t_; @@ -147,7 +182,14 @@ struct AttentionKernel { static constexpr int64_t kWarpSize = KernelInfo::kWarpSize; struct MM0 { - using GemmType = GemmTypeQK; + /* + In this first matmul, we compute a block of `Q @ K.T`. 
+ While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; using OpClass = typename GemmType::OpClass; using DefaultConfig = @@ -156,13 +198,13 @@ struct AttentionKernel { ArchTag, scalar_t, scalar_t, - accum_t, // ElementC + scalar_t, // ElementC accum_t // ElementAccumulator >; static constexpr int64_t kAlignmentA = - kIsAligned ? DefaultConfig::kAlignmentA : 1; + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; static constexpr int64_t kAlignmentB = - kIsAligned ? DefaultConfig::kAlignmentB : 1; + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; using ThreadblockShape = cutlass::gemm::GemmShape< kQueriesPerBlock, kNumWarpsPerBlock * kWarpSize, @@ -183,7 +225,8 @@ struct AttentionKernel { ThreadblockShape, // ThreadblockShape WarpShape, // WarpShape typename GemmType::InstructionShape, // InstructionShape - 2, // Should use `DefaultConfig::kStages`, but that uses too much smem + DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that + // uses too much smem typename GemmType::Operator // Operator >::DefaultMma; using MmaCore = typename DefaultMma::MmaCore; @@ -195,65 +238,97 @@ struct AttentionKernel { accum_t, kQueriesPerBlock, kWarpSize>::Updater; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; }; struct MM1 { - using ThreadblockShape = cutlass::gemm:: - GemmShape; - using WarpShape = cutlass::gemm::GemmShape; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - - // default_mma_core_simt.h - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - ThreadblockShape, // ThreadblockShape, - WarpShape, // WarpShape, - InstructionShape, // InstructionShape, - accum_t, // ElementA, + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int64_t kAlignmentA = + DefaultConfig::kAlignmentA; // from smem + static constexpr int64_t kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm::GemmShape< + kQueriesPerBlock, + kNumWarpsPerBlock * kWarpSize, + GemmType::ThreadK>; + using WarpShape = + cutlass::gemm::GemmShape; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, scalar_t, // ElementB, - cutlass::layout::RowMajor, // LayoutB, - accum_t, // ElementC, + LayoutB, // LayoutB, + kAlignmentB, + output_t, cutlass::layout::RowMajor, // LayoutC, - cutlass::arch::OpClassSimt, - 2, // Stages, - cutlass::arch::OpMultiplyAdd // Operator, - >; - - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - typename MmaCore::ElementA, - typename MmaCore::LayoutA, - 1, - typename MmaCore::IteratorThreadMapA, - MmaCore::IteratorThreadMapA::kElementsPerAccess, // AccessSize - false, // Gather - false // LoadFromGlobalMemoryOnly - >; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - typename MmaCore::ElementB, - typename MmaCore::LayoutB, - 0, - typename MmaCore::IteratorThreadMapB>; - - using Mma = cutlass::gemm::threadblock::MmaPipelined< - typename MmaCore::Shape, - IteratorA, - typename MmaCore::SmemIteratorA, - IteratorB, - typename MmaCore::SmemIteratorB, - typename MmaCore::ElementC, - typename MmaCore::LayoutC, - typename MmaCore::MmaPolicy>; + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage>; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = typename DefaultEpilogue::OutputTileIterator; + + struct SharedStorageMM1 { + union { + // Storing parts of `V` during the matmul + typename Mma::SharedStorage mm; + // Used by the Epilogue (so we can reuse the same memory space) + typename DefaultEpilogue::SharedStorage epilogue; + }; + }; static __device__ void compute_dot_product_att_value( - typename Mma::SharedStorage& shared_storage, + SharedStorageMM1& shared_storage_mm, + typename MM0::AccumulatorSharedStorage& shared_storage_si, int32_t const& iter_key_start, at::TensorAccessor& value, cutlass::Array const& m_prime, - accum_t si[kQueriesPerBlock][kNumWarpsPerBlock * kWarpSize], + cutlass::Array const& s_prime, + bool isLast, at::TensorAccessor& output) { cutlass::gemm::GemmCoord problem_size( @@ -264,40 +339,27 @@ struct AttentionKernel { int32_t(kNumWarpsPerBlock * kWarpSize), value.size(0) - iter_key_start) // K ); - typename IteratorA::Params params_A(kNumWarpsPerBlock * kWarpSize); - typename IteratorA::TensorRef ref_A( - &si[0][0], kNumWarpsPerBlock * kWarpSize); - typename IteratorB::Params params_B( - typename MmaCore::LayoutB(value.stride(0))); + typename IteratorB::Params params_B(LayoutB(value.stride(0))); typename IteratorB::TensorRef ref_B( &value[iter_key_start][0], 
value.stride(0)); static_assert( - MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * - MmaCore::WarpCount::kK == - kNumWarpsPerBlock); + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock); const int64_t nBlockN = ceil_div((int64_t)problem_size.n(), int64_t(ThreadblockShape::kN)); for (int blockN = 0; blockN < nBlockN; ++blockN) { - // Compute threadblock location + /* + Run the matmul `attn @ V` for a block of attn and V. + `attn` is read from shared memory (in `shared_storage_si`) + `V` is read from global memory (with iterator_B) + */ cutlass::gemm::GemmCoord tb_tile_offset = {0, blockN, 0}; - cutlass::MatrixCoord tb_offset_A{ - tb_tile_offset.m() * Mma::Shape::kM, tb_tile_offset.k()}; - cutlass::MatrixCoord tb_offset_B{ tb_tile_offset.k(), tb_tile_offset.n() * Mma::Shape::kN}; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params_A, - ref_A.data(), - {problem_size.m(), problem_size.k()}, - thread_id(), - tb_offset_A); - typename Mma::IteratorB iterator_B( params_B, ref_B.data(), @@ -305,74 +367,80 @@ struct AttentionKernel { thread_id(), tb_offset_B); - // Construct thread-scoped matrix multiply - Mma mma(shared_storage, thread_id(), warp_id(), lane_id()); - - auto iterator_C_offset_m = (tb_tile_offset.m() * Mma::WarpCount::kM) + - (warp_id() % Mma::WarpCount::kM); - auto iterator_C_offset_n = (tb_tile_offset.n() * Mma::WarpCount::kN) + - (warp_id() / Mma::WarpCount::kM); - using LaneMmaShape = typename Mma::Policy; - typename Mma::Operator::IteratorC::Policy::LaneLayout lane_layout = - Mma::Operator::IteratorC::Policy::get_lane_layout(); - cutlass::MatrixCoord lane_offset = - lane_layout.inverse(lane_id()) * - cutlass::MatrixCoord( - Mma::Operator::IteratorC::Policy::LaneMmaShape::kM, - Mma::Operator::IteratorC::Policy::LaneMmaShape::kN); - - typename Mma::FragmentC accum, - accum2; // cutlass::Array - // TODO: We could avoid all this mess using cutlass's Epilogue concept I - // think but I got lost in templates and reimplemented everything - - const int32_t thread_offset_m = - Mma::WarpGemm::kM * iterator_C_offset_m + lane_offset.row(); - const int32_t thread_offset_n = - Mma::WarpGemm::kN * iterator_C_offset_n + lane_offset.column(); - output_t* output_ptr = &output[query_start()][0]; - const int32_t output_s0 = output.stride(0); - const int32_t max_m = output.size(0) - query_start(); - const int32_t max_n = output.size(1); - - // Load data already calculated, and rescale it (as the max value for - // the softmax might have changed) Technically, we could do that on - // `accum`, but then we would have to wait for load to finish to start - // the gemm calculations. 
Let's rather load it in parallel (software - // pipelining) on another register `accum2` + typename Mma::FragmentC accum; accum.clear(); - accum2.clear(); - iterate_on_frag( - accum2, - thread_offset_m, - thread_offset_n, - [&](typename Mma::FragmentC::reference accum_v, - int32_t m, - int32_t n) { - if (m < max_m && n < max_n) { - accum_v = accum_t(output_ptr[m * output_s0 + n]) * m_prime[m]; - } - }); + + Mma mma( + shared_storage_mm.mm, + shared_storage_si, + thread_id(), + warp_id(), + lane_id()); + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - - // Add discounted `v_prime` (stored in `accum2`) to `accum` (which will - // be stored to `output`) - accum = cutlass::plus()(accum, accum2); - iterate_on_frag( - accum, - thread_offset_m, - thread_offset_n, - [&](typename Mma::FragmentC::reference accum_v, - int32_t const& m, - int32_t const& n) { - if (m < max_m && n < max_n) { - output_ptr[m * output_s0 + n] = output_t(accum_v); - } - }); + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + mma(gemm_k_iterations, accum, iterator_B, accum); + + /* + Epilogue: Store the following into global memory + output <- alpha * accumulator + beta * source + with: + alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) + beta = alpha / m_prime (renormalize the output when the max + changes) source is the current output + */ + OutputTileIterator output_tile_it( + typename OutputTileIterator::Params{output.stride(0)}, + &output[query_start()][0], + {output.size(0) - query_start(), output.size(1)}, + thread_id()); + OutputTileIterator source_tile_it( + typename OutputTileIterator::Params{output.stride(0)}, + &output[query_start()][0], + {output.size(0) - query_start(), output.size(1)}, + thread_id()); + + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&]() { + DISPATCH_BOOL( + isLast, kIsLast, ([&]() { + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + output_t, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp:: + ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp:: + ElementCompute, + kIsFirst, + kIsLast>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpilogueWithRowId< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename DefaultEpilogue::OutputTileIterator, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage_mm.epilogue, + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, output_tile_it, accum, source_tile_it); + })); + })); } } }; @@ -381,16 +449,20 @@ struct AttentionKernel { static constexpr int64_t kAlignmentK = MM0::kAlignmentB; static constexpr int64_t kAlignmentV = 1; - struct SharedStorageGlobal { - accum_t si[kQueriesPerBlock][kNumWarpsPerBlock * kWarpSize]; + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; cutlass::Array mi; - typename MM1::Mma::SharedStorage mm1; + typename MM1::SharedStorageMM1 mm1; }; - union 
SharedStorage { - // Shared storage needed by threadblock-scoped matrix multiply-accumulate - typename MM0::Mma::SharedStorage mm0; - SharedStorageGlobal after_mm0; + struct SharedStorage { + cutlass::Array m_prime; + cutlass::Array s_prime; + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; }; static void __device__ attention_kernel( @@ -411,9 +483,10 @@ struct AttentionKernel { int32_t num_queries = query.size(0); int32_t K = key.size(1); - __shared__ cutlass::Array m_prime; - __shared__ cutlass::Array s_prime; - __shared__ SharedStorage shared_storage; + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.after_mm0.mi; @@ -428,26 +501,30 @@ struct AttentionKernel { iter_key_start += kNumWarpsPerBlock * kWarpSize) { __syncthreads(); // Need to have shared memory initialized, and `m_prime` // updated from end of prev iter - // 1. Compute dot-product into shared memory for each query // also calculates `mi`, and updates `m_prime` / `s_prime` compute_dot_product_qk( iter_key_start, query, key, m_prime, s_prime, shared_storage); __syncthreads(); + bool isLast = + (iter_key_start + kNumWarpsPerBlock * kWarpSize) >= num_keys; - // 4. Partial matmull with the values we have and V + // 4. Partial matmul with the values we have and V // `v* <- v* . exp(m* - mi) + v_i . exp(si - mi)` MM1::compute_dot_product_att_value( shared_storage.after_mm0.mm1, + shared_storage.after_mm0.si, iter_key_start, value, m_prime, - si, + s_prime, + isLast, // 6. Divide by s_prime all of the values on the last + // iteration output); __syncthreads(); // we modify `m_prime` after - // 5. `m_prime` <- `mi` + // 5. `m_prime` <- `mi` (`mi` will be overwritten during MM0) if (warp_id == 0) { static_assert(kQueriesPerBlock == kWarpSize); m_prime[lane_id] = mi[lane_id]; @@ -455,28 +532,6 @@ struct AttentionKernel { __syncthreads(); } - // 6. Divide by s_prime all of the values - const int32_t output_stride0 = output.stride(0); - const int32_t iter_col_last = output.size(1) - lane_id; - int32_t iter_query_last = std::min( - (int32_t)kQueriesPerBlock, - int32_t(num_queries - warp_id - query_start())); - if (iter_col_last > 0 && iter_query_last > 0) { - // &output[query_start()][thread_id] - output_t* output_line_ptr = - output.data() + (query_start() + warp_id) * output_stride0 + lane_id; - for (int32_t q = 0; q < iter_query_last; - q += kNumWarpsPerBlock) { // parallel warps - auto line_s_prime = s_prime[q + warp_id]; - for (int32_t value_col = 0; value_col < iter_col_last; - value_col += kWarpSize) { // parallel lanes - output_line_ptr[value_col] = - output_t(accum_t(output_line_ptr[value_col]) / line_s_prime); - } - output_line_ptr += output_stride0 * kNumWarpsPerBlock; - } - } - // 7. Calculate logsumexp if (logsumexp.size(0) && warp_id == 0) { static_assert(kQueriesPerBlock == kWarpSize); @@ -487,44 +542,6 @@ struct AttentionKernel { } } - template - static void __device__ __forceinline__ iterate_on_frag( - Fragment& frag, - int32_t const& offset_m, - int32_t const& offset_n, - FN callback) { - // TODO: This is quite hacky, and only needed for Simt. For other Mmas, we - // can use epilogue. 
- using Policy = typename Iterator::Policy; - using Delta = typename Iterator::Delta; - using Iterations = typename Iterator::Iterations; - using Element = typename Iterator::Element; - - CUTLASS_PRAGMA_UNROLL - for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { // 0 - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { - CUTLASS_PRAGMA_UNROLL - for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { - callback( - frag.at( - n + - Policy::LaneMmaShape::kN * - (mma_n + - Iterations::kColumn * - (m + mma_m * Policy::LaneMmaShape::kM))), - offset_m + m + mma_m * Delta::kRow, - offset_n + n + - mma_n * Policy::WarpShape::kColumn * - Policy::LaneMmaShape::kN); - } - } - } - } - } - static __device__ void compute_dot_product_qk( int32_t const& iter_key_start, at::TensorAccessor& query, @@ -537,7 +554,7 @@ struct AttentionKernel { (a) query[query_start:query_end, :] with (b) key[iter_key_start:iter_key_start + kNumWarpsPerBlock * kWarpSize] - and stores that into `si` + and stores that into `shared_storage.si` */ using MmaCore = typename MM0::MmaCore; using Mma = typename MM0::Mma; @@ -606,9 +623,7 @@ struct AttentionKernel { // Compute threadblock-scoped matrix multiply-add mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - __syncthreads(); - auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.after_mm0.mi; if (my_warp_id == 0) { static_assert(kQueriesPerBlock == kWarpSize); @@ -626,41 +641,25 @@ struct AttentionKernel { (tb_tile_offset.n() * Mma::WarpCount::kN) + (my_warp_id / Mma::WarpCount::kM)}; // Update `mi` from accum stored in registers - typename MM0::ScalingCoefsUpdater updater; - updater.update( + MM0::ScalingCoefsUpdater::update( accum, mi, m_prime, s_prime, - my_lane_id, - my_warp_id, + lane_id(), + warp_id(), key.size(0) - iter_key_start, iteratorC_tile_offset); - // Output results - typename Mma::Operator::IteratorC iterator_C( - {&si[0][0], kNumWarpsPerBlock * kWarpSize}, my_lane_id); - - iterator_C.add_tile_offset(iteratorC_tile_offset); - iterator_C.store(accum); - } - - static __device__ __forceinline__ accum_t warpMax(accum_t val) { - for (int stride = kWarpSize / 2; stride > 0; stride >>= 1) { - accum_t tmp = - accum_t(__shfl_xor_sync(0xffffffff, val, stride, kWarpSize)); - val = tmp > val ? 
tmp : val; - } - return val; - } + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; - static __device__ __forceinline__ accum_t warpSum(accum_t val) { - for (int stride = kWarpSize / 2; stride > 0; stride >>= 1) { - accum_t tmp = - accum_t(__shfl_xor_sync(0xffffffff, val, stride, kWarpSize)); - val += tmp; - } - return val; + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); } static __device__ __forceinline__ int8_t lane_id() { @@ -684,6 +683,7 @@ __global__ void __launch_bounds__( AKInfo::kWarpSize* AKInfo::kNumWarpsPerBlock, // minBlocksPerMultiprocessor is optional and specifies the desired minimum // number of resident blocks per multiprocessor + // TODO: We get slightly better performance by *removing* this on A100 12 / AKInfo::kNumWarpsPerBlock) attention_kernel_batched( at::PackedTensorAccessor32 output, @@ -787,7 +787,7 @@ efficient_attention_forward_generic( using ArchTag = cutlass::arch::Sm75; \ func(); \ } else if (computeCapability >= 70) { \ - using ArchTag = cutlass::arch::Sm75; \ + using ArchTag = cutlass::arch::Sm70; \ func(); \ } else if (computeCapability >= 50) { \ using ArchTag = cutlass::arch::Sm50; \ @@ -814,17 +814,6 @@ efficient_attention_forward_generic( } \ } -#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ - { \ - if (BOOL_V) { \ - constexpr bool BOOL_NAME = true; \ - F(); \ - } else { \ - constexpr bool BOOL_NAME = false; \ - F(); \ - } \ - } - DISPATCH_TYPES(([&]() { // Run a more efficient kernel (with `isAligned=True`) if memory is // correctly aligned @@ -836,33 +825,51 @@ efficient_attention_forward_generic( (query.stride(1) % AlignedAK::kAlignmentQ == 0 && key.stride(1) % AlignedAK::kAlignmentK == 0 && value.stride(1) % AlignedAK::kAlignmentV == 0); + // TODO: Should we warn or log somewhere when we use a less efficient + // kernel due to wrong alignment? 
})); DISPATCH_BOOL( - isAligned, IsAligned, ([&]() { + isAligned, kIsAligned, ([&]() { using AKI = - AttentionKernelInfo; + AttentionKernelInfo; + size_t smem_bytes = 0; DISPATCH_ARCHTAG(([&]() { using AK = AttentionKernel; - TORCH_INTERNAL_ASSERT( + smem_bytes = sizeof(typename AK::SharedStorage); + // Might happen on Sm80/half, where the minimum alignment is 32bits + TORCH_CHECK( query.stride(1) % AK::kAlignmentQ == 0, "query is not correctly aligned"); - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( key.stride(1) % AK::kAlignmentK == 0, "key is not correctly aligned"); - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( value.stride(1) % AK::kAlignmentV == 0, "value is not correctly aligned"); })); - using m = math; + TORCH_INTERNAL_ASSERT(smem_bytes > 0, "No kernel found!?"); res = at::zeros( - {B, M, K}, query.options().dtype(math::kAtScalarType)); + {B, M, K}, + query.options().dtype(TypeTraits::kAtScalarType)); dim3 grid(AKI::kNumBlocksX, AKI::getNumBlocksY(M), B); dim3 block(AKI::kWarpSize, AKI::kNumWarpsPerBlock, 1); - attention_kernel_batched<<>>( - math::packed_accessor<3>(res), + constexpr auto kernel_fn = attention_kernel_batched; + if (smem_bytes > 48000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + cudaFuncSetAttribute( + kernel_fn, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + } + + using m = TypeTraits; + kernel_fn<<>>( + TypeTraits::packed_accessor<3>(res), logsumexp.packed_accessor32(), m::packed_accessor<3>(query), m::packed_accessor<3>(key), diff --git a/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h b/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h index ae5d1df11e..aa7924de04 100644 --- a/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h +++ b/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h @@ -476,4 +476,4 @@ struct DefaultAttentionScalingCoefsUpdater< accum_t, kQueriesPerBlock, kWarpSize>; -}; +}; \ No newline at end of file diff --git a/xformers/components/attention/csrc/cuda/epilogue_rescale_output.h b/xformers/components/attention/csrc/cuda/epilogue_rescale_output.h new file mode 100644 index 0000000000..5313c74dee --- /dev/null +++ b/xformers/components/attention/csrc/cuda/epilogue_rescale_output.h @@ -0,0 +1,707 @@ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. 
+ + This is a copy of: + cutlass/epilogue/thread/linear_combination.h + (MemoryEfficientAttentionNormalize) cutlass/epilogue/threadblock/epilogue.h + (EpilogueWithRowId) With a few modifications so that: (1) The Epilogue passes + the row id to the OutputOp (MemoryEfficientAttentionNormalize here) Note that + in general the fragment passed to the OutputOp could span multiple rows but it + does not happen with the configurations we have :) (2) + `MemoryEfficientAttentionNormalize` takes the `s_prime` and `m_prime` vectors + (rather than scalars in `LinearCombination`) and renormalizes the output +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_ = Array, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentOutput const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? 
(1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output + ///< tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value)> +class EpilogueWithRowId : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, 
+ OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpilogueWithRowId( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + 
accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileIterator::Fragment source_fragment; + + source_fragment.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + // + // Load the source + // + + source_iterator.load(source_fragment); + ++source_iterator; + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + OutputAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + int row = + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess); + // Call the output operator + output_frag_ptr[i] = + output_op(row, compute_frag_ptr[i], source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + int row = + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess); + // Call the output operator + output_frag_ptr[i] = output_op(row, compute_frag_ptr[i]); + } + } + + constexpr int getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + for (int 
cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/attention/csrc/cuda/find_default_mma.h b/xformers/components/attention/csrc/cuda/find_default_mma.h index 87956ecfe4..a9f4aae920 100644 --- a/xformers/components/attention/csrc/cuda/find_default_mma.h +++ b/xformers/components/attention/csrc/cuda/find_default_mma.h @@ -52,12 +52,11 @@ template < int Stages, /// Operation perfomed by GEMM typename Operator, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> + typename Enable_ = void> struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; using DefaultMma = cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, @@ -80,227 +79,77 @@ struct FindDefaultMma { /// Specialization for sm80 / FastF32 / multistage with kStages=2 template < + typename ElementA_, /// Layout type for A matrix operand typename LayoutA_, /// Access granularity of A matrix in units of elements int kAlignmentA, + typename ElementB_, /// Layout type for B matrix operand typename LayoutB_, /// Access granularity of B matrix in units of elements int kAlignmentB, + typename ElementAccumulator, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape> + typename InstructionShape, + int kStages, + typename Operator> struct FindDefaultMma< - float, + ElementA_, LayoutA_, kAlignmentA, - float, + ElementB_, LayoutB_, kAlignmentB, - float, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, - 2, - arch::OpMultiplyAddFastF32> { - struct DefaultMma { - static constexpr int kStages = 2; - static SharedMemoryClearOption constexpr SharedMemoryClear = - SharedMemoryClearOption::kNone; - using ElementA = float; - using ElementB = float; - using ElementAccumulator = float; - using LayoutC = layout::RowMajor; - using Operator = arch::OpMultiplyAddFastF32; - static constexpr bool GatherA = false; - static constexpr bool GatherB = false; - - static_assert( - std::is_same::value || - std::is_same>::value, - "simt epilogue must be row major"); - - static cutlass::arch::CacheOperation::Kind 
const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // In theory we should do the following, but it would match the template for - // MmaPipelined - and we want MmaMultiStage! - // using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< - // ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - // ElementB, LayoutB, ElementAccumulator, LayoutC, - // arch::OpClassTensorOp, Stages, Operator, false, - // CacheOpA, CacheOpB>; - struct MmaCore { - using LayoutA = LayoutA_; - using LayoutB = LayoutB_; - using Shape = ThreadblockShape; - using ElementC = ElementAccumulator; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK>; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - /// Size of a threadblock-scoped access - static int const kAccessSizeInBits = 128; - - // Warp thread arrangement - static int const kWarpThreadArrangementContiguousA = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedA = - kWarpSize / kWarpThreadArrangementContiguousA; - - static int const kWarpThreadArrangementContiguousB = - Shape::kK / (kAccessSizeInBits / sizeof_bits::value); - - static int const kWarpThreadArrangementStridedB = - kWarpSize / kWarpThreadArrangementContiguousB; - - // - // Shared memory layouts - // - - using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, - Shape::kK>; - - // Shared memory layout - using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, - Shape::kK>; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape< - kWarpThreadArrangementContiguousA, - kWarpThreadArrangementStridedA>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 0, - IteratorThreadMapA>; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - kThreads, - layout::PitchLinearShape< - kWarpThreadArrangementContiguousB, - kWarpThreadArrangementStridedB>, - kAccessSizeInBits / sizeof_bits::value>; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 1, - IteratorThreadMapB>; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using MmaTensorOp = typename 
cutlass::gemm::warp::DefaultMmaTensorOp< - WarpShape, - InstructionShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - Operator, - WarpCount::kK>::Type; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - MmaTensorOp, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK>; - }; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA_, - 1, - ThreadMapA, - AccessTypeA, - GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = - cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, - LayoutB_, - 0, - ThreadMapB, - AccessTypeB, - GatherB>; - + kStages, + Operator, + typename std::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< - typename MmaCore::Shape, - IteratorA, - typename MmaCore::SmemIteratorA, - MmaCore::kCacheOpA, - IteratorB, - typename MmaCore::SmemIteratorB, - MmaCore::kCacheOpB, + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, ElementAccumulator, LayoutC, - typename MmaCore::MmaPolicy, - kStages, - SharedMemoryClear>; + typename MmaCore_::MmaPolicy, + kStages>; }; }; diff --git a/xformers/components/attention/csrc/cuda/mma_from_smem.h b/xformers/components/attention/csrc/cuda/mma_from_smem.h new file mode 100644 index 0000000000..7281c17ec5 --- /dev/null +++ b/xformers/components/attention/csrc/cuda/mma_from_smem.h @@ -0,0 +1,1509 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA, + // Accumulator type + typename AccumulatorSharedStorage, + // END smem + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory + : public MmaBaseFromSharedMemory { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three 
coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + // Accumulator type + typename AccumulatorSharedStorage, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory + : public MmaBaseFromSharedMemory { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = 
typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLDGSTSIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_( + accumulator_shared_storage.accum_ref(), + lane_idx), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + 
reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment + // skip warp tile loading for the last kgroup + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +namespace { +template +struct AssertIsSame { + static_assert(std::is_same::value); + using CHECK = bool; +}; +} // namespace + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema + using WarpIterator = RegularWarpIterator; +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the 
Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + typename AccumulatorSharedStorage_> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + AccumulatorSharedStorage_> { + static constexpr int kWarpSize = 32; + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + using RegularMma = MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< + WarpShape, + InstructionShape, + typename RegularMma::Operator::IteratorA, + Policy_>::WarpIterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + AccumulatorSharedStorage_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + typename AccumulatorSharedStorage_> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, 
+ LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + AccumulatorSharedStorage_> { + static constexpr int kWarpSize = 32; + + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< + WarpShape, + InstructionShape, + typename RegularMma::Operator::IteratorA, + Policy_>::WarpIterator; + + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + AccumulatorSharedStorage_, + IteratorB_, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. 
Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + static void __device__ accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< + WarpShape, + cutlass::gemm::GemmShape<32, 32, 4>, + scalar_t, + SmemAccumulatorLayout>; + + // // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + 16, + 32>, // typename SmemIteratorD0::TensorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void __device__ accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using OutputLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using EleShapePerPatial = typename IteratorC::EleShapePerPatial; + using QuadShapePerPatialMma = typename IteratorC::QuadShapePerPatialMma; + using Element = accum_t; + + auto constexpr kElementsPerMma = IteratorC::kElementsPerMma; + auto constexpr kElementsPerPartial = IteratorC::kElementsPerPartial; + auto constexpr kAccumulatorPatials = IteratorC::kAccumulatorPatials; + + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // 
(quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {SmemIteratorD0::Shape::kRow, SmemIteratorD0::Shape::kColumn})); + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n; + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int r = (accum_m + lane_offset.row()); + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + } + } + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + + using SmemAccumulatorLayout = cutlass::layout::ColumnMajor; + using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorSimt< + WarpShape, + typename Operator::ArchMmaOperator, + scalar_t, + cutlass::layout::RowMajor, // XXX: only supports rowmajor ... 
+ OperatorPolicy>; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void __device__ accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + using TensorRef = + typename cutlass::TensorRef; + + TensorRef ref_(shared_storage.accum_ref()); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {SmemIteratorD0::Shape::kRow, SmemIteratorD0::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/ops.py b/xformers/ops.py index efbc00a094..13821d11bb 100644 --- a/xformers/ops.py +++ b/xformers/ops.py @@ -6,7 +6,7 @@ import math from dataclasses import dataclass -from typing import Any, List, Optional, Set, Type, Union +from typing import Any, List, Mapping, Optional, Set, Type, Union import torch @@ -117,11 +117,15 @@ class AttentionOpBase(torch.autograd.Function): """ FORWARD_OPERATOR: Any - FORWARD_ERROR_ATOL: float = 2e-4 + FORWARD_ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 2e-4, + torch.half: 2e-3, + torch.bfloat16: 2e-3, + } SUPPORTED_DEVICES: Set[str] SUPPORTED_DTYPES: Set[torch.dtype] SUPPORTED_MAX_K: float - SUPPORTS_attn_bias_type: Set[Any] = {type(None)} + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)} SUPPORTS_DROPOUT: bool NAME: str @@ -168,7 +172,7 @@ def supports(cls, d: "AttentionOpDispatch") -> bool: return False if d.k > cls.SUPPORTED_MAX_K: return False - if d.attn_bias_type not in cls.SUPPORTS_attn_bias_type: + if d.attn_bias_type not in cls.SUPPORTED_ATTN_BIAS_TYPES: return False if d.has_dropout and not cls.SUPPORTS_DROPOUT: return False @@ -180,7 +184,7 @@ class MemoryEfficientAttentionOp(AttentionOpBase): SUPPORTED_DEVICES = {"cuda", "cpu"} SUPPORTED_DTYPES = {torch.float} SUPPORTED_MAX_K: float = 32 - SUPPORTS_attn_bias_type: Set[Any] = {type(None), torch.Tensor} + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), torch.Tensor} SUPPORTS_DROPOUT = True NAME = "small_k" @@ -201,10 +205,35 @@ class 
MemoryEfficientAttentionGenericForwardOp(AttentionOpBase): SUPPORTED_DEVICES = {"cuda"} SUPPORTED_DTYPES = {torch.float, torch.half} SUPPORTED_MAX_K = math.inf - SUPPORTS_attn_bias_type: Set[Any] = {type(None)} + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)} SUPPORTS_DROPOUT = False NAME = "fwd_gen" + @classmethod + def uses_tensorcores(cls, d: "AttentionOpDispatch", is_half: bool) -> bool: + sm_major = torch.cuda.get_device_capability(d.device)[0] + if sm_major >= 8: + return True + if sm_major >= 7: + return is_half + return False + + @classmethod + def supports(cls, d: "AttentionOpDispatch") -> bool: + if not super(MemoryEfficientAttentionGenericForwardOp, cls).supports(d): + return False + is_sm80 = torch.cuda.get_device_capability(d.device)[0] >= 8 + bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[d.dtype] + uses_tensorcores = cls.uses_tensorcores(d, bits_per_scalar == 16) + if is_sm80 and d.k % 4 != 0: + return False + if uses_tensorcores and d.k % (64 / bits_per_scalar) != 0: + return False + warp_shape = 32 if uses_tensorcores else 8 + if d.kv_len > warp_shape and d.kv_len % warp_shape != 0: + return False + return True + @classmethod def backward(cls, ctx, grad): query, key, value, lse, attn_bias, out = ctx.saved_tensors @@ -234,11 +263,14 @@ class MemoryEfficientAttentionFlashAttentionOp(AttentionOpBase): """ FORWARD_OPERATOR = None - FORWARD_ERROR_ATOL = 5e-2 + FORWARD_ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } SUPPORTED_DEVICES = {"cuda"} SUPPORTED_DTYPES = {torch.half, torch.bfloat16} SUPPORTED_MAX_K = 128 - SUPPORTS_attn_bias_type: Set[Any] = {type(None), LowerTriangularMask} + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), LowerTriangularMask} SUPPORTS_DROPOUT = False NAME = "flshatt" @@ -469,13 +501,14 @@ class AttentionOpDispatch: k: int has_dropout: bool attn_bias_type: Any + kv_len: int @property def op(self) -> Type[AttentionOpBase]: priority_list_ops: List[Type[AttentionOpBase]] = [ MemoryEfficientAttentionFlashAttentionOp, - MemoryEfficientAttentionOp, MemoryEfficientAttentionGenericForwardOp, + MemoryEfficientAttentionOp, ] for op in priority_list_ops: if op.supports(self): @@ -497,6 +530,7 @@ def from_arguments( k=query.shape[-1], has_dropout=p > 0.0, attn_bias_type=type(attn_bias), + kv_len=value.shape[-2], )
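
A minimal standalone sketch of the shape checks introduced in MemoryEfficientAttentionGenericForwardOp.supports() above, assuming the same TensorCore selection rule as uses_tensorcores(); generic_fwd_alignment_ok is a hypothetical helper (not part of xformers) so the arithmetic can be exercised without a CUDA device:

def generic_fwd_alignment_ok(sm_major: int, bits_per_scalar: int, k: int, kv_len: int) -> bool:
    # TensorCores are used on Sm80+ for every supported dtype, and on Sm70/Sm75 for 16-bit dtypes only.
    uses_tensorcores = sm_major >= 8 or (sm_major >= 7 and bits_per_scalar == 16)
    # Sm80 requires the per-head embedding dim to be a multiple of 4.
    if sm_major >= 8 and k % 4 != 0:
        return False
    # TensorCore kernels additionally require k to be a multiple of 64 / bits_per_scalar
    # (4 for 16-bit dtypes, 2 for f32).
    if uses_tensorcores and k % (64 // bits_per_scalar) != 0:
        return False
    # kv_len must be a multiple of warp_shape once it exceeds it (32 with TensorCores, 8 otherwise).
    warp_shape = 32 if uses_tensorcores else 8
    if kv_len > warp_shape and kv_len % warp_shape != 0:
        return False
    return True

# A100 (sm80), f16, k=64, kv_len=1024: accepted.
assert generic_fwd_alignment_ok(8, 16, 64, 1024)
# V100 (sm70), f16, k=30: rejected, since 30 is not a multiple of 4.
assert not generic_fwd_alignment_ok(7, 16, 30, 128)

When these checks fail, AttentionOpDispatch.op simply moves on to the next entry in priority_list_ops.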