Skip to content

Commit

Permalink
bwaccf32: Accumulate in f32 for bw
Browse files Browse the repository at this point in the history
ghstack-source-id: d13ff3e407510b98ea4daf9cf38a74d966fa4d59
Pull Request resolved: #467
  • Loading branch information
danthe3rd committed Nov 29, 2022
1 parent b516aec commit 6f8e2f1
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ mem_efficient_attention_backward_cutlass(

int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t Mkv = key.size(1);
int64_t N = key.size(1);
int64_t nH = query.size(2);
int64_t K = query.size(3);
int64_t Kv = value.size(3);

// It does not make sense to use that in practice,
// but let's still make sure we are correct
Expand All @@ -133,6 +135,7 @@ mem_efficient_attention_backward_cutlass(
grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
}
at::Tensor workspace;

auto launchKernel = [&](auto _k, int computeCapability) {
using Kernel = decltype(_k);
Expand Down Expand Up @@ -205,6 +208,12 @@ mem_efficient_attention_backward_cutlass(
ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));

int64_t size_bytes = p.workspace_size();
if (size_bytes) {
workspace =
at::empty({size_bytes}, query.options().dtype(at::ScalarType::Byte));
p.workspace = (float*)workspace.data_ptr();
}
Kernel::check_supported(p);

constexpr auto kernel_fn = attention_kernel_backward_batched<Kernel>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}

// Rounds `n` up to the nearest multiple of `m` (for non-negative `n`,
// positive `m`): the smallest multiple of `m` that is >= `n`.
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
  integer const num_blocks = (n + m - 1) / m;
  return num_blocks * m;
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
Expand Down
Loading

0 comments on commit 6f8e2f1

Please sign in to comment.