bwaccf32: Accumulate in f32 for bw
ghstack-source-id: 29c8918b52a893018f19cebeb5367bdd6202405b
Pull Request resolved: #467
danthe3rd committed Nov 29, 2022
1 parent: 1515f77 · commit: 228e0cd
Showing 4 changed files with 269 additions and 90 deletions.
13 changes: 3 additions & 10 deletions tests/test_mem_eff_attention.py
@@ -3,7 +3,6 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
 import random
 from dataclasses import dataclass
 from typing import Any, Sequence, Type
@@ -258,21 +257,15 @@ def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
 
 
 def backward_error_atol(k, kv_len, q_len, dtype):
-    atol = 2e-4 + 2e-6 * k * kv_len * math.sqrt(q_len)
+    atol = 2e-4 + 2e-6 * k
     rtol = 1e-4
     if dtype is torch.half:
-        atol = 5e-2
+        atol = 8e-2
         rtol = 2e-2
-        # TODO: Implement f32 accumulation for bw
-        # Longer sequences mean we iterate more and errors accumulate
-        atol *= 1.4 ** (max(q_len, kv_len) // 64)
     if dtype is torch.bfloat16:
         # I've seen (out=-1.9 and ref=-1.0 with flash)
-        atol = 0.5
+        atol = 0.7
         rtol = 0.1
-        # TODO: Implement f32 accumulation for bw
-        # Longer sequences mean we iterate more and errors accumulate
-        atol *= 1.4 ** (max(q_len, kv_len) // 64)
     return atol, rtol
 
 
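The tolerance rewrite above is the user-visible effect of this commit: the test asserts closeness with the usual |out - ref| <= atol + rtol * |ref| rule, and the old `atol *= 1.4 ** (max(q_len, kv_len) // 64)` growth factor existed only because accumulating the backward pass in half/bf16 lets rounding error compound on every iteration over the sequence. A standalone sketch (not part of the commit) that emulates bf16 by keeping only the top 16 bits of a float and compares rounding after every addition against accumulating in f32:

    // Standalone sketch: per-step bf16 rounding vs. f32 accumulation.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float to_bf16(float x) {
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      bits &= 0xFFFF0000u;  // keep sign, exponent, 7 mantissa bits (bf16 truncation)
      std::memcpy(&x, &bits, sizeof(bits));
      return x;
    }

    int main() {
      const int n = 4096;     // stands in for a long sequence of accumulation steps
      double ref = 0.0;       // high-precision reference sum
      float acc_bf16 = 0.0f;  // rounded to bf16 after every addition
      float acc_f32 = 0.0f;   // accumulated in f32, rounded to bf16 once at the end
      for (int i = 0; i < n; ++i) {
        // bf16 inputs close to 1.0, as a single gradient contribution might be
        float v = to_bf16(1.0f + 0.01f * std::sin(0.37f * float(i)));
        ref += v;
        acc_bf16 = to_bf16(acc_bf16 + v);
        acc_f32 += v;
      }
      std::printf("reference sum:          %.2f\n", ref);
      std::printf("bf16-accumulated error: %.2f\n", std::fabs(acc_bf16 - ref));
      std::printf("f32-accumulated error:  %.2f\n", std::fabs(to_bf16(acc_f32) - ref));
      return 0;
    }

The bf16 accumulator stalls once its magnitude makes a single addition smaller than one ulp, which is exactly the length-dependent error the old tolerances had to absorb; with f32 accumulation only the final cast rounds, so the error no longer grows with the number of steps.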
@@ -104,9 +104,11 @@ mem_efficient_attention_backward_cutlass(
 
   int64_t B = query.size(0);
   int64_t M = query.size(1);
+  int64_t Mkv = key.size(1);
   int64_t N = key.size(1);
   int64_t nH = query.size(2);
   int64_t K = query.size(3);
+  int64_t Kv = value.size(3);
 
   // It does not make sense to use that in practice,
   // but let's still make sure we are correct
@@ -133,6 +135,7 @@
     grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
     grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
   }
+  at::Tensor workspace;
 
   auto launchKernel = [&](auto _k, int computeCapability) {
     using Kernel = decltype(_k);
@@ -205,6 +208,12 @@ mem_efficient_attention_backward_cutlass(
     ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
     ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));
 
+    int64_t size_bytes = p.workspace_size();
+    if (size_bytes) {
+      workspace =
+          at::empty({size_bytes}, query.options().dtype(at::ScalarType::Byte));
+      p.workspace = (float*)workspace.data_ptr();
+    }
     Kernel::check_supported(p);
 
     constexpr auto kernel_fn = attention_kernel_backward_batched<Kernel>;
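The host-side part of the change above is a common CUDA pattern: the kernel parameters report how many scratch bytes they need (`p.workspace_size()`), the caller allocates an untyped byte tensor on the right device, and the kernel reinterprets that memory as `float` storage for its f32 accumulators. A minimal sketch of the same ownership pattern outside ATen; the `Params` struct, the element count, and the 16-byte padding below are illustrative assumptions, not the kernel's actual layout:

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for the kernel's parameter struct.
    struct Params {
      int64_t num_accum_elements = 0;  // e.g. gradient values accumulated in f32
      float* workspace = nullptr;      // filled in by the caller

      // Scratch size in bytes, padded (here to 16 bytes) for aligned accesses.
      int64_t workspace_size() const {
        auto align_up = [](int64_t n, int64_t m) { return ((n + m - 1) / m) * m; };
        return align_up(num_accum_elements * int64_t(sizeof(float)), int64_t(16));
      }
    };

    int main() {
      Params p;
      p.num_accum_elements = 1000;

      // The byte buffer plays the role of the at::Tensor `workspace` above:
      // it owns the memory, while the kernel only sees a raw float pointer.
      std::vector<uint8_t> workspace;
      if (int64_t size_bytes = p.workspace_size()) {
        workspace.resize(size_bytes);
        p.workspace = reinterpret_cast<float*>(workspace.data());
      }
      return 0;
    }

Allocating bytes rather than floats mirrors the `at::ScalarType::Byte` choice in the diff: the reported size is an opaque number of bytes rather than a count of any particular element type.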
@@ -105,6 +105,11 @@ constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
 }
 
+template <typename integer>
+constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
+  return ((n + m - 1) / m) * m;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Determine the type of GEMM we do (TensorCores or not, Shapes ...)
 // TODO: Maybe we could rely on Cutlass's DefaultGemm templates
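The new `align_up` helper complements the existing `ceil_div`: `ceil_div(n, m)` is the number of m-sized blocks needed to cover n elements, while `align_up(n, m)` is n rounded up to the next multiple of m, which is what you want when padding a buffer size or a stride. A small usage sketch with the two helpers copied locally (the `CUTLASS_HOST_DEVICE` qualifier dropped) so it compiles without CUTLASS:

    template <typename integer>
    constexpr integer ceil_div(integer n, integer m) {
      return (n + m - 1) / m;
    }

    template <typename integer>
    constexpr integer align_up(integer n, integer m) {
      return ((n + m - 1) / m) * m;
    }

    int main() {
      static_assert(ceil_div(130, 32) == 5, "five 32-wide blocks cover 130 elements");
      static_assert(align_up(130, 32) == 160, "padded length is 5 * 32");
      static_assert(align_up(128, 32) == 128, "already-aligned sizes are unchanged");
      return 0;
    }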
(The diff for the fourth changed file did not load and is not shown here.)
