From 1f37d54b177698100a518df2e96ca6c2f11b7778 Mon Sep 17 00:00:00 2001
From: danthe3rd
Date: Mon, 28 Nov 2022 16:39:43 +0000
Subject: [PATCH] bwaccf32: Accumulate in f32 for bw

ghstack-source-id: 48369de3f8b94eb3c190ac2b0a1b3ddf6003e5ff
Pull Request resolved: https://github.com/facebookresearch/xformers/pull/467
---
 .../attention_backward_generic.cu             |   9 +
 .../mem_eff_attention/gemm_kernel_utils.h     |   5 +
 .../cuda/mem_eff_attention/kernel_backward.h  | 202 +++++++++++++-----
 3 files changed, 165 insertions(+), 51 deletions(-)

diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu b/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu
index f3fb29ff5d..c5b6e37788 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu
@@ -104,9 +104,11 @@ mem_efficient_attention_backward_cutlass(
 
   int64_t B = query.size(0);
   int64_t M = query.size(1);
+  int64_t Mkv = key.size(1);
   int64_t N = key.size(1);
   int64_t nH = query.size(2);
   int64_t K = query.size(3);
+  int64_t Kv = value.size(3);
 
   // It does not make sense to use that in practice,
   // but let's still make sure we are correct
@@ -133,6 +135,7 @@ mem_efficient_attention_backward_cutlass(
     grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
     grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
   }
+  at::Tensor workspace;
 
   auto launchKernel = [&](auto _k, int computeCapability) {
     using Kernel = decltype(_k);
@@ -205,6 +208,12 @@ mem_efficient_attention_backward_cutlass(
     ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
     ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));
 
+    int64_t size_bytes = p.workspace_size();
+    if (size_bytes) {
+      workspace = at::empty(
+          {size_bytes}, query.options().dtype(at::ScalarType::Byte));
+      p.workspace = (float*)workspace.data_ptr();
+    }
     Kernel::check_supported(p);
 
     constexpr auto kernel_fn = attention_kernel_backward_batched<Kernel>;
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h
index 48994e2ab5..3ac85c3d8f 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h
@@ -105,6 +105,11 @@ constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
 }
 
+template <typename integer>
+constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
+  return ((n + m - 1) / m) * m;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Determine the type of GEMM we do (TensorCores or not, Shapes ...)
 // TODO: Maybe we could rely on Cutlass's DefaultGemm templates
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
index da255b562a..e107e507c9 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
@@ -39,7 +39,59 @@ using namespace gemm_kernel_utils;
+
+
 namespace {
+
+template <typename FragmentType, int32_t kNumThreads>
+struct GmemTile {
+  // 128bits per thread
+  using AccessType = cutlass::Array<float, 4>;
+  static constexpr int32_t kBytes = sizeof(AccessType);
+  static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
+  static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements;
+  static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements;
+  static_assert(FragmentType::kElements % AccessType::kElements == 0, "fragment not aligned on 128 bits");
+
+  float* __restrict__ ptr;
+
+  CUTLASS_DEVICE void prefetch(int thread_id) {
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < kNumIters; ++i) {
+      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(ptr + thread_id * AccessType::kElements + i * kStride);
+      uint64_t addr = (uint64_t)((void*)gmem_ptr);
+      asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
+    }
+  }
+
+  CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) {
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < kNumIters; ++i) {
+      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(ptr + thread_id * AccessType::kElements + i * kStride);
+      AccessType sub_fragment;
+      cutlass::arch::global_load<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < AccessType::kElements; ++j) {
+        fragment[i * AccessType::kElements + j] = sub_fragment[j];
+      }
+    }
+  }
+
+  CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < kNumIters; ++i) {
+      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(ptr + thread_id * AccessType::kElements + i * kStride);
+      AccessType sub_fragment;
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < AccessType::kElements; ++j) {
+        sub_fragment[j] = fragment[i * AccessType::kElements + j];
+      }
+      cutlass::arch::global_store<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
+    }
+  }
+};
+
 template <typename scalar_t, typename Arch>
 constexpr int getWarpsPerSm() {
   bool is_half = !std::is_same<scalar_t, float>::value;
@@ -62,6 +114,7 @@ template <
 struct AttentionBackwardKernel {
   using scalar_t = scalar_t_;
   using output_t = scalar_t;
+  using output_accum_t = float;
   using lse_scalar_t = float;
   using accum_t = float;
   using ArchTag = ArchTag_;
@@ -81,6 +134,9 @@ struct AttentionBackwardKernel {
     output_t* grad_query_ptr; // [Mq, nH, K]
     output_t* grad_key_ptr; // [Mk, nH, K]
     output_t* grad_value_ptr; // [Mk, nH, Kv]
+    // Accumulators
+    output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
+    output_accum_t* workspace_end = nullptr; // Only used in debug mode
 
     // Scale
     accum_t scale;
@@ -137,7 +193,7 @@ struct AttentionBackwardKernel {
       constexpr int32_t kAlignLSE = 32; // block size of backward
       auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
 
-      int32_t batch_id = blockIdx.z;
+      int64_t batch_id = blockIdx.z;
       int32_t head_id = blockIdx.y;
 
       query_ptr += batch_id * q_strideB + head_id * q_strideH;
@@ -175,6 +231,16 @@ struct AttentionBackwardKernel {
       grad_query_ptr = warp_uniform(grad_query_ptr);
       grad_key_ptr = warp_uniform(grad_key_ptr);
       grad_value_ptr = warp_uniform(grad_value_ptr);
+
+      if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
+        assert(workspace != nullptr);
+        // format: [B, H, M, K]
+        workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
+        workspace_end = workspace + workspace_strideBH();
+        workspace = warp_uniform(workspace);
+      } else {
+        workspace = nullptr;
+      }
     }
 
     __host__ dim3 getBlocksGrid() const {
@@ -183,6 +249,31 @@ struct AttentionBackwardKernel {
     __host__ dim3 getThreadsGrid() const {
       return dim3(kWarpSize, kNumWarpsPerBlock, 1);
     }
+    CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
+      if (!kNeedsAccumGradQ) {
+        return 0;
+      }
+      return align_up(num_queries, (int32_t)kBlockSizeI) * align_up(head_dim, (int32_t)kBlockSizeJ);
+    }
+    CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
+      if (!kNeedsAccumGradK) {
+        return 0;
+      }
+      return align_up(num_keys, (int32_t)kBlockSizeJ) * align_up(head_dim, (int32_t)kBlockSizeI);
+    }
+    CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
+      if (!kNeedsAccumGradV) {
+        return 0;
+      }
+      return align_up(num_keys, (int32_t)kBlockSizeJ) * align_up(head_dim_value, (int32_t)kBlockSizeI);
+    }
+    CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
+      return align_up(workspace_elements_gq() + workspace_elements_gk() + workspace_elements_gv(), 4L);
+    }
+    CUTLASS_HOST_DEVICE int64_t workspace_size() const {
+      // Returns memory we need as buffer in gmem to run this kernel
+      return num_batches * num_heads * workspace_strideBH() * sizeof(float);
+    }
   };
 
   // Block I
@@ -221,6 +312,13 @@ struct AttentionBackwardKernel {
   static constexpr bool kKernelComputesDelta =
       kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);
 
+  static constexpr bool kNeedsAccumGradQ =
+      !std::is_same<output_t, output_accum_t>::value;
+  static constexpr bool kNeedsAccumGradK =
+      !kOutputInRF && !std::is_same<output_t, output_accum_t>::value;
+  static constexpr bool kNeedsAccumGradV =
+      !kOutputInRF && !std::is_same<output_t, output_accum_t>::value;
+
   // Launch bounds
   static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
   static constexpr int64_t kMinBlocksPerSm =
@@ -420,6 +518,7 @@ struct AttentionBackwardKernel {
     using OutputTileIterator =
         typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
            typename DefaultEpilogue::OutputTileIterator>::Iterator;
+    using AccumTileIterator = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
   };
 
   struct MatmulGradK {
     // grad_k <- tmp.transpose(-2, -1) @ q_i
@@ -712,9 +811,9 @@ struct AttentionBackwardKernel {
     extern __shared__ char smem_buffer[];
     SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
 
-    auto getQueryStart = [&](int32_t key_start) {
+    auto getQueryStart = [&](int32_t key_start) -> int32_t {
       if (p.causal) {
-        return key_start;
+        return (key_start / kBlockSizeI) * kBlockSizeI;
       }
       return 0;
     };
@@ -1136,18 +1235,12 @@ struct AttentionBackwardKernel {
     for (int col = 0; col < p.head_dim;
          col += MatmulGradQ::ThreadblockShape::kN) {
       using Mma = typename MatmulGradQ::Mma;
+      using AccumTileIterator = typename MatmulGradQ::AccumTileIterator;
 
       cutlass::gemm::GemmCoord problem_size(
          num_queries_in_block,
          false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col,
          num_keys_in_block);
-      auto createEpilogueIter = [&]() {
-        return typename MatmulGradQ::OutputTileIterator(
-            typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
-            p.grad_query_ptr + query_start * p.gQ_strideM() + col,
-            {problem_size.m(), problem_size.n()},
-            thread_id);
-      };
 
       // k_j
       typename Mma::IteratorB iterator_B(
@@ -1168,13 +1261,23 @@ struct AttentionBackwardKernel {
 
      typename Mma::FragmentC accum;
 
-      accum.clear();
+      bool isFirst = key_start == 0;
+      float* __restrict__ gmem_accum_ptr = p.workspace;
+      int col_id = col / MatmulGradQ::ThreadblockShape::kN;
+      int storage_id = (col_id + query_start / kBlockSizeI * ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN));
+      gmem_accum_ptr += storage_id * AccumTileIterator::kElementsStored;
+      assert(gmem_accum_ptr < p.workspace_end);
+      if (isFirst || !kNeedsAccumGradQ) {
+        accum.clear();
+      } else {
+        AccumTileIterator gmem_tile{gmem_accum_ptr};
+        gmem_tile.ptr = gmem_accum_ptr;
+        gmem_tile.load(accum, thread_id);
+      }
 
      auto gemm_k_iterations =
          (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
 
-      // Start prefetching output tile now to make the epilogue faster
-      createEpilogueIter().prefetch_all();
      // Compute threadblock-scoped matrix multiply-add
      __syncthreads();
      mma.set_prologue_done(kPrologueGQ);
@@ -1186,44 +1289,27 @@ struct AttentionBackwardKernel {
       }
 
       // Output results
-      typename MatmulGradQ::OutputTileIterator output_it = createEpilogueIter();
-      DISPATCH_BOOL(
-          key_start == 0, kIsFirst, ([&]() {
-            using DefaultEpilogue = typename MatmulGradQ::DefaultEpilogue;
-            using DefaultOutputOp = typename MatmulGradQ::DefaultOutputOp;
-            static constexpr auto ScaleType = kIsFirst
-                ? cutlass::epilogue::thread::ScaleType::Nothing
-                : cutlass::epilogue::thread::ScaleType::NoBetaScaling;
-            using EpilogueOutputOp =
-                typename cutlass::epilogue::thread::LinearCombination<
-                    typename DefaultOutputOp::ElementOutput,
-                    DefaultOutputOp::kCount,
-                    typename DefaultOutputOp::ElementAccumulator,
-                    typename DefaultOutputOp::ElementCompute,
-                    ScaleType>;
-            using Epilogue =
-                typename cutlass::epilogue::threadblock::EpiloguePipelined<
-                    typename DefaultEpilogue::Shape,
-                    typename Mma::Operator,
-                    DefaultEpilogue::kPartitionsK,
-                    typename MatmulGradQ::OutputTileIterator,
-                    typename DefaultEpilogue::AccumulatorFragmentIterator,
-                    typename DefaultEpilogue::WarpTileIterator,
-                    typename DefaultEpilogue::SharedLoadIterator,
-                    EpilogueOutputOp,
-                    typename DefaultEpilogue::Padding,
-                    DefaultEpilogue::kFragmentsPerIteration,
-                    true // IterationsUnroll
-                    >;
-            EpilogueOutputOp rescale({1, 1});
-            Epilogue epilogue(
-                isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
-                             : shared_storage.gradQ_epilogue(),
-                thread_id,
-                warp_id,
-                lane_id);
-            epilogue(rescale, output_it, accum, output_it);
-          }));
+      int32_t next_query, next_key;
+      incrIteration(p, p.num_queries, key_start, next_query, next_key);
+      bool isLast =
+          (p.causal && next_query > query_start) || next_key >= p.num_keys;
+      if (kNeedsAccumGradQ && !isLast) {
+        AccumTileIterator gmem_tile{gmem_accum_ptr};
+        gmem_tile.ptr = gmem_accum_ptr;
+        gmem_tile.store(accum, thread_id);
+      } else {
+        typename MatmulGradQ::OutputTileIterator output_it(
+            typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
+            p.grad_query_ptr + query_start * p.gQ_strideM() + col,
+            {problem_size.m(), problem_size.n()},
+            thread_id);
+        accumulateInGmem<MatmulGradQ>(
+            isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
+                         : shared_storage.gradQ_epilogue(),
+            accum,
+            output_it,
+            isFirst || kNeedsAccumGradQ);
+      }
     }
     /////////////////////////////////////////////////////////////////////////////////////////////////
     // GradK matmul
@@ -1319,6 +1405,20 @@ struct AttentionBackwardKernel {
     }
   }
 
+  static CUTLASS_DEVICE void incrIteration(
+      Params const& p,
+      int32_t query_start,
+      int32_t key_start,
+      int32_t& next_query,
+      int32_t& next_key) {
+    next_query = query_start + kBlockSizeI;
+    next_key = key_start;
+    if (next_query >= p.num_queries) {
+      next_key = key_start + kBlockSizeJ;
+      next_query = p.causal ? (next_key / kBlockSizeI) * kBlockSizeI : 0;
+    }
+  }
+
   template <bool kForceReloadK>
   static CUTLASS_DEVICE void prologueQkNextIteration(
       SharedStorage& shared_storage,