bwaccf32: Accumulate in f32 for bw
ghstack-source-id: 48369de3f8b94eb3c190ac2b0a1b3ddf6003e5ff
Pull Request resolved: #467
danthe3rd committed Nov 28, 2022
1 parent b516aec commit 1f37d54
Showing 3 changed files with 165 additions and 51 deletions.
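What the commit does: in the memory-efficient attention backward pass, partial gradients produced block by block are now accumulated in a float32 global-memory workspace (laid out as [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv], with output_accum_t = float) and converted to the half-precision output type only once, instead of being repeatedly added into the half-precision output. The sketch below is not code from this commit; it is a minimal, self-contained illustration of why accumulation precision matters, using a crude 10-bit-mantissa truncation (to_half_approx) as a stand-in for fp16 and made-up block counts and values.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Crude stand-in for fp16: keep only 10 of the 23 float mantissa bits
// (truncation, not round-to-nearest). Illustrative only.
static float to_half_approx(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const int kNumKeyBlocks = 4096;  // hypothetical number of key blocks
  const float partial = 0.001f;    // hypothetical per-block gradient contribution

  // (a) Add every partial directly into a half-precision output.
  float out_half = 0.f;
  for (int j = 0; j < kNumKeyBlocks; ++j) {
    out_half = to_half_approx(out_half + to_half_approx(partial));
  }

  // (b) Keep the running sum in an f32 workspace, convert once at the end.
  float workspace = 0.f;
  for (int j = 0; j < kNumKeyBlocks; ++j) {
    workspace += to_half_approx(partial);
  }
  const float out_f32_accum = to_half_approx(workspace);

  // (a) stalls once one contribution is smaller than an ulp of the running
  // sum; (b) stays close to the exact value.
  std::printf("half accum: %f   f32 accum: %f   exact: %f\n",
              out_half, out_f32_accum, kNumKeyBlocks * (double)partial);
  return 0;
}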
File 1 of 3
@@ -104,9 +104,11 @@ mem_efficient_attention_backward_cutlass(

int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t Mkv = key.size(1);
int64_t N = key.size(1);
int64_t nH = query.size(2);
int64_t K = query.size(3);
int64_t Kv = value.size(3);

// It does not make sense to use that in practice,
// but let's still make sure we are correct
@@ -133,6 +135,7 @@ mem_efficient_attention_backward_cutlass(
grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
}
at::Tensor workspace;

auto launchKernel = [&](auto _k, int computeCapability) {
using Kernel = decltype(_k);
@@ -205,6 +208,12 @@ mem_efficient_attention_backward_cutlass(
ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));

int64_t size_bytes = p.workspace_size();
if (size_bytes) {
workspace = at::empty(
{size_bytes}, query.options().dtype(at::ScalarType::Byte));
p.workspace = (float*)workspace.data_ptr();
}
Kernel::check_supported(p);

constexpr auto kernel_fn = attention_kernel_backward_batched<Kernel>;
File 2 of 3
@@ -105,6 +105,11 @@ constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
return ((n + m - 1) / m) * m;
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
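A quick sanity check of the align_up helper added above (the values and assertions here are illustrative, not part of the commit):

// Same rounding rule as align_up, restated without the CUTLASS_HOST_DEVICE macro.
constexpr int align_up_demo(int n, int m) { return ((n + m - 1) / m) * m; }
static_assert(align_up_demo(100, 64) == 128, "rounds up to the next multiple of m");
static_assert(align_up_demo(128, 64) == 128, "already-aligned values are unchanged");
static_assert(align_up_demo(7, 4) == 8, "works for any positive block size");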
File 3 of 3
@@ -39,7 +39,59 @@

using namespace gemm_kernel_utils;



namespace {

template <typename FragmentType, int32_t kNumThreads>
struct GmemTile {
// 128bits per thread
using AccessType = cutlass::Array<float, 4>;
static constexpr int32_t kBytes = sizeof(AccessType);
static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements;
static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements;
static_assert(FragmentType::kElements % AccessType::kElements == 0, "fragment not aligned on 128 bits");

float* __restrict__ ptr;

CUTLASS_DEVICE void prefetch(int thread_id) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNumIters; ++i) {
AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(ptr + thread_id * AccessType::kElements + i * kStride);
uint64_t addr = (uint64_t)((void*)gmem_ptr);
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
}
}

CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNumIters; ++i) {
AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(ptr + thread_id * AccessType::kElements + i * kStride);
AccessType sub_fragment;
cutlass::arch::global_load<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccessType::kElements; ++j) {
fragment[i * AccessType::kElements + j] = sub_fragment[j];
}
}
}

CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNumIters; ++i) {
AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(ptr + thread_id * AccessType::kElements + i * kStride);
AccessType sub_fragment;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccessType::kElements; ++j) {
sub_fragment[j] = fragment[i * AccessType::kElements + j];
}
cutlass::arch::global_store<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
}
}
};
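GmemTile spreads one accumulator fragment per thread across global memory in 128-bit (4-float) chunks: thread t touches elements t * AccessType::kElements + i * kStride in iteration i, so a block of kNumThreads threads covers kNumThreads * FragmentType::kElements consecutive floats exactly once. The host-side sketch below only checks that indexing property with made-up sizes (128 threads, 16 floats per fragment); it is not the CUDA/CUTLASS code above.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes, chosen only to illustrate the layout.
  const int kNumThreads = 128;       // threads per block
  const int kFragmentElements = 16;  // floats held per thread
  const int kAccessElements = 4;     // 4 floats = 128 bits per access
  const int kStride = kNumThreads * kAccessElements;
  const int kNumIters = kFragmentElements / kAccessElements;
  const int kElementsStored = kNumThreads * kFragmentElements;

  std::vector<int> hits(kElementsStored, 0);
  for (int t = 0; t < kNumThreads; ++t) {
    for (int i = 0; i < kNumIters; ++i) {
      for (int j = 0; j < kAccessElements; ++j) {
        hits[t * kAccessElements + i * kStride + j] += 1;
      }
    }
  }
  for (int e = 0; e < kElementsStored; ++e) {
    assert(hits[e] == 1);  // every float is touched by exactly one thread
  }
  std::printf("covered %d floats exactly once\n", kElementsStored);
  return 0;
}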


template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
bool is_half = !std::is_same<scalar_t, float>::value;
@@ -62,6 +114,7 @@ template <
struct AttentionBackwardKernel {
using scalar_t = scalar_t_;
using output_t = scalar_t;
using output_accum_t = float;
using lse_scalar_t = float;
using accum_t = float;
using ArchTag = ArchTag_;
@@ -81,6 +134,9 @@ struct AttentionBackwardKernel {
output_t* grad_query_ptr; // [Mq, nH, K]
output_t* grad_key_ptr; // [Mk, nH, K]
output_t* grad_value_ptr; // [Mk, nH, Kv]
// Accumulators
output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
output_accum_t* workspace_end = nullptr; // Only used in debug mode

// Scale
accum_t scale;
@@ -137,7 +193,7 @@ struct AttentionBackwardKernel {
constexpr int32_t kAlignLSE = 32; // block size of backward
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;

int32_t batch_id = blockIdx.z;
int64_t batch_id = blockIdx.z;
int32_t head_id = blockIdx.y;

query_ptr += batch_id * q_strideB + head_id * q_strideH;
@@ -175,6 +231,16 @@ struct AttentionBackwardKernel {
grad_query_ptr = warp_uniform(grad_query_ptr);
grad_key_ptr = warp_uniform(grad_key_ptr);
grad_value_ptr = warp_uniform(grad_value_ptr);

if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
assert(workspace != nullptr);
// format: [B, H, M, K]
workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
workspace_end = workspace + workspace_strideBH();
workspace = warp_uniform(workspace);
} else {
workspace = nullptr;
}
}

__host__ dim3 getBlocksGrid() const {
@@ -183,6 +249,31 @@ struct AttentionBackwardKernel {
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
if (!kNeedsAccumGradQ) {
return 0;
}
return align_up(num_queries, (int32_t)kBlockSizeI) * align_up(head_dim, (int32_t)kBlockSizeJ);
}
CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
if (!kNeedsAccumGradK) {
return 0;
}
return align_up(num_keys, (int32_t)kBlockSizeJ) * align_up(head_dim, (int32_t)kBlockSizeI);
}
CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
if (!kNeedsAccumGradV) {
return 0;
}
return align_up(num_keys, (int32_t)kBlockSizeJ) * align_up(head_dim_value, (int32_t)kBlockSizeI);
}
CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
return align_up(workspace_elements_gq() + workspace_elements_gk() + workspace_elements_gv(), 4L);
}
CUTLASS_HOST_DEVICE int64_t workspace_size() const {
// Returns memory we need as buffer in gmem to run this kernel
return num_batches * num_heads * workspace_strideBH() * sizeof(float);
}
};

// Block I
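To make the workspace arithmetic concrete, here is a small host-side rework of workspace_elements_gq/gk/gv, workspace_strideBH and workspace_size with made-up problem sizes (batch 2, 8 heads, 1000 queries and keys, head_dim and head_dim_value 64) and assumed tile sizes kBlockSizeI = 64, kBlockSizeJ = 128; it also assumes all three accumulators are needed. It mirrors the formulas above but is not code from the kernel.

#include <cstdint>
#include <cstdio>

constexpr int64_t align_up(int64_t n, int64_t m) { return ((n + m - 1) / m) * m; }

int main() {
  // All values below are assumptions for the example, not the kernel's defaults.
  const int64_t num_batches = 2, num_heads = 8;
  const int64_t num_queries = 1000, num_keys = 1000;
  const int64_t head_dim = 64, head_dim_value = 64;
  const int64_t kBlockSizeI = 64, kBlockSizeJ = 128;

  // Mirrors workspace_elements_gq/gk/gv when every accumulator is needed.
  const int64_t gq = align_up(num_queries, kBlockSizeI) * align_up(head_dim, kBlockSizeJ);
  const int64_t gk = align_up(num_keys, kBlockSizeJ) * align_up(head_dim, kBlockSizeI);
  const int64_t gv = align_up(num_keys, kBlockSizeJ) * align_up(head_dim_value, kBlockSizeI);

  // One stride per (batch, head) pair, rounded up to a multiple of 4 floats.
  const int64_t strideBH = align_up(gq + gk + gv, 4);
  const int64_t bytes =
      num_batches * num_heads * strideBH * static_cast<int64_t>(sizeof(float));

  std::printf("gq=%lld gk=%lld gv=%lld strideBH=%lld total=%lld bytes\n",
              (long long)gq, (long long)gk, (long long)gv,
              (long long)strideBH, (long long)bytes);
  return 0;
}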
@@ -221,6 +312,13 @@ struct AttentionBackwardKernel {
static constexpr bool kKernelComputesDelta =
kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);

static constexpr bool kNeedsAccumGradQ =
!std::is_same<output_accum_t, output_t>::value;
static constexpr bool kNeedsAccumGradK =
!kOutputInRF && !std::is_same<output_accum_t, output_t>::value;
static constexpr bool kNeedsAccumGradV =
!kOutputInRF && !std::is_same<output_accum_t, output_t>::value;

// Launch bounds
static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int64_t kMinBlocksPerSm =
@@ -420,6 +518,7 @@ struct AttentionBackwardKernel {
using OutputTileIterator =
typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
typename DefaultEpilogue::OutputTileIterator>::Iterator;
using AccumTileIterator = GmemTile<typename Mma::FragmentC, kNumThreads>;
};
struct MatmulGradK {
// grad_k <- tmp.transpose(-2, -1) @ q_i
@@ -712,9 +811,9 @@ struct AttentionBackwardKernel {
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);

auto getQueryStart = [&](int32_t key_start) {
auto getQueryStart = [&](int32_t key_start) -> int32_t {
if (p.causal) {
return key_start;
return (key_start / kBlockSizeI) * kBlockSizeI;
}
return 0;
};
@@ -1136,18 +1235,12 @@ struct AttentionBackwardKernel {
for (int col = 0; col < p.head_dim;
col += MatmulGradQ::ThreadblockShape::kN) {
using Mma = typename MatmulGradQ::Mma;
using AccumTileIterator = typename MatmulGradQ::AccumTileIterator;

cutlass::gemm::GemmCoord problem_size(
num_queries_in_block,
false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col,
num_keys_in_block);
auto createEpilogueIter = [&]() {
return typename MatmulGradQ::OutputTileIterator(
typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
p.grad_query_ptr + query_start * p.gQ_strideM() + col,
{problem_size.m(), problem_size.n()},
thread_id);
};

// k_j
typename Mma::IteratorB iterator_B(
@@ -1168,13 +1261,23 @@

typename Mma::FragmentC accum;

accum.clear();
bool isFirst = key_start == 0;
float* __restrict__ gmem_accum_ptr = p.workspace;
int col_id = col / MatmulGradQ::ThreadblockShape::kN;
int storage_id = (col_id + query_start / kBlockSizeI * ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN));
gmem_accum_ptr += storage_id * AccumTileIterator::kElementsStored;
assert(gmem_accum_ptr < p.workspace_end);
if (isFirst || !kNeedsAccumGradQ) {
accum.clear();
} else {
AccumTileIterator gmem_tile{gmem_accum_ptr};
gmem_tile.ptr = gmem_accum_ptr;
gmem_tile.load(accum, thread_id);
}

auto gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

// Start prefetching output tile now to make the epilogue faster
createEpilogueIter().prefetch_all();
// Compute threadblock-scoped matrix multiply-add
__syncthreads();
mma.set_prologue_done(kPrologueGQ);
@@ -1186,44 +1289,27 @@
}

// Output results
typename MatmulGradQ::OutputTileIterator output_it = createEpilogueIter();
DISPATCH_BOOL(
key_start == 0, kIsFirst, ([&]() {
using DefaultEpilogue = typename MatmulGradQ::DefaultEpilogue;
using DefaultOutputOp = typename MatmulGradQ::DefaultOutputOp;
static constexpr auto ScaleType = kIsFirst
? cutlass::epilogue::thread::ScaleType::Nothing
: cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::LinearCombination<
typename DefaultOutputOp::ElementOutput,
DefaultOutputOp::kCount,
typename DefaultOutputOp::ElementAccumulator,
typename DefaultOutputOp::ElementCompute,
ScaleType>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MatmulGradQ::OutputTileIterator,
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true // IterationsUnroll
>;
EpilogueOutputOp rescale({1, 1});
Epilogue epilogue(
isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
: shared_storage.gradQ_epilogue(),
thread_id,
warp_id,
lane_id);
epilogue(rescale, output_it, accum, output_it);
}));
int32_t next_query, next_key;
incrIteration(p, p.num_queries, key_start, next_query, next_key);
bool isLast =
(p.causal && next_query > query_start) || next_key >= p.num_keys;
if (kNeedsAccumGradQ && !isLast) {
AccumTileIterator gmem_tile{gmem_accum_ptr};
gmem_tile.ptr = gmem_accum_ptr;
gmem_tile.store(accum, thread_id);
} else {
typename MatmulGradQ::OutputTileIterator output_it(
typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
p.grad_query_ptr + query_start * p.gQ_strideM() + col,
{problem_size.m(), problem_size.n()},
thread_id);
accumulateInGmem<MatmulGradQ>(
isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
: shared_storage.gradQ_epilogue(),
accum,
output_it,
isFirst || kNeedsAccumGradQ);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// GradK matmul
@@ -1319,6 +1405,20 @@ struct AttentionBackwardKernel {
}
}

static CUTLASS_DEVICE void incrIteration(
Params const& p,
int32_t query_start,
int32_t key_start,
int32_t& next_query,
int32_t& next_key) {
next_query = query_start + kBlockSizeI;
next_key = key_start;
if (next_query >= p.num_queries) {
next_key = key_start + kBlockSizeJ;
next_query = p.causal ? (next_key / kBlockSizeI) * kBlockSizeI : 0;
}
}
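Putting getQueryStart and incrIteration together: the outer loop walks key blocks in steps of kBlockSizeJ, and for each key block the inner loop walks query blocks in steps of kBlockSizeI, starting (in the causal case) at the first query block that can attend to that key block. A given query block is therefore revisited once per contributing key block, which is why its grad_q partial sums have to survive in the f32 workspace between outer iterations. The trace below reproduces that visiting order on the host with assumed sizes (num_queries = num_keys = 384, kBlockSizeI = 64, kBlockSizeJ = 128, causal); it is only a sketch of the loop structure, not kernel code.

#include <cstdio>
#include <map>

int main() {
  // Assumed sizes, for the trace only.
  const int num_queries = 384, num_keys = 384;
  const int kBlockSizeI = 64, kBlockSizeJ = 128;
  const bool causal = true;

  std::map<int, int> visits_per_query_block;
  for (int key_start = 0; key_start < num_keys; key_start += kBlockSizeJ) {
    // getQueryStart: under causal masking, query blocks that end before this
    // key block produce no gradient and are skipped.
    const int first_query = causal ? (key_start / kBlockSizeI) * kBlockSizeI : 0;
    for (int query_start = first_query; query_start < num_queries;
         query_start += kBlockSizeI) {
      std::printf("key_start=%3d  query_start=%3d\n", key_start, query_start);
      visits_per_query_block[query_start] += 1;
    }
  }
  // Query blocks visited more than once need their grad_q partials staged in
  // the f32 workspace until the last contributing key block.
  for (const auto& kv : visits_per_query_block) {
    std::printf("query block %3d accumulates over %d key blocks\n",
                kv.first, kv.second);
  }
  return 0;
}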

template <bool kForceReloadK>
static CUTLASS_DEVICE void prologueQkNextIteration(
SharedStorage& shared_storage,