
Commit 96c98ec

MemoryEff attention forward: Properly fuse matmul and enable TensorCores on the second matmul (facebookresearch#368)

* Generic backwards

* Guard backward to sm75 only

* bounds checking for gradV

* clang-format

* Fused gemm working for Sm80/Sm75 f16/f32

* WIP

* Volta TensorOp for f16

* Working on A100 again

* SIMT working

* Code cleanup 1

* Code cleanup2

* BUGFIX for shared memory limit

* Remove code

* clang-format

* Remove code again

* Remove draft of backward

* Enforce alignment for fp16

* Fix tests

* Fix constraint on seq length when not using tensorcores

* Fix alignment requirements for V100/tensorcores

* Clang-format

* Update xformers/components/attention/csrc/cuda/attention_forward_generic.cu

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Address comments from fmassa

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
3 people authored Jul 25, 2022
1 parent ab47700 commit 96c98ec
Showing 12 changed files with 2,656 additions and 540 deletions.
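
For context, the kernel this commit touches computes a standard attention forward pass: a first matmul Q @ K^T, a softmax, and a second matmul against V. The commit fuses that second matmul into the kernel and runs it on tensor cores. The sketch below is a minimal PyTorch illustration of the math only, not the fused CUDA kernel, and the exact scaling and bias handling in xformers may differ:

import math

import torch


def attention_forward_reference(q, k, v, attn_bias=None):
    # Illustration only: generic scaled-dot-product attention.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # first matmul: Q @ K^T
    if attn_bias is not None:
        scores = scores + attn_bias
    probs = scores.softmax(dim=-1)
    return probs @ v  # second matmul: attn @ V, the one this commit fuses onto tensor cores
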
setup.py (5 changes: 4 additions & 1 deletion)

@@ -172,7 +172,10 @@ def get_extensions():
         if cuda_version >= 1102:
             nvcc_flags += ["--threads", "4"]
         extra_compile_args["nvcc"] = nvcc_flags
-        if cuda_version >= 1100 and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0":
+        if (
+            cuda_version >= 1100
+            and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0"
+        ):
             ext_modules += get_flash_attention_extensions(
                 cuda_version=cuda_version, extra_compile_args=extra_compile_args
             )

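The new multi-line condition is purely a reformatting of the existing build gate. Restated as a standalone helper (the function name below is hypothetical and not defined in setup.py; cuda_version is the integer encoding setup.py compares against, with 1100 presumably meaning CUDA 11.0):

import os


def should_build_flash_attention(cuda_version: int) -> bool:
    # Build the flash-attention extension only for sufficiently new CUDA,
    # unless the user opts out by setting XFORMERS_DISABLE_FLASH_ATTN=1.
    return (
        cuda_version >= 1100
        and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0"
    )
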
tests/test_mem_eff_attention.py (2 changes: 1 addition & 1 deletion)

@@ -112,7 +112,7 @@ def test_memory_efficient_attention(
     ).float()
     ref = ref_attention(query, key, value, attn_bias)
 
-    assert_allclose(out, ref, atol=op.FORWARD_ERROR_ATOL)
+    assert_allclose(out, ref, atol=op.FORWARD_ERROR_ATOL[dtype])
 
 
 @pytest.mark.parametrize("k_len", [5, 6, 32])

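The test now looks up the forward tolerance per dtype instead of using a single value, since the f16 and f32 paths accumulate different rounding error. A sketch of what such a per-dtype table looks like; the numbers here are purely illustrative and are not the tolerances xformers actually defines:

import torch

# Illustrative only: the real values live on the op class in xformers.
FORWARD_ERROR_ATOL = {
    torch.float32: 3e-4,
    torch.float16: 2e-2,
}

# The assertion then picks the tolerance for the dtype under test:
# assert_allclose(out, ref, atol=op.FORWARD_ERROR_ATOL[dtype])
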
third_party/cutlass/include/cutlass/arch/mma.h (9 changes: 5 additions & 4 deletions)

@@ -143,16 +143,17 @@ template <
   /// Layout of B matrix (concept: MatrixLayout)
   typename LayoutB,
   /// Element type of C matrix
-  typename ElementC,
+  typename ElementC_,
   /// Layout of C matrix (concept: MatrixLayout)
   typename LayoutC,
   /// Inner product operator
   typename Operator_
 >
-struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {
+struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;
   using Operator = Operator_;
+  using ElementC = ElementC_;
 
   CUTLASS_HOST_DEVICE
   void operator()(

@@ -218,8 +219,8 @@ struct SparseMma;
 #include "cutlass/arch/mma_sm50.h"
 #include "cutlass/arch/mma_sm60.h"
 #include "cutlass/arch/mma_sm61.h"
-#include "cutlass/arch/mma_sm70.h"
-#include "cutlass/arch/mma_sm75.h"
+#include "cutlass/arch/mma_sm70.h"
+#include "cutlass/arch/mma_sm75.h"
 #include "cutlass/arch/mma_sm80.h"
 #include "cutlass/arch/mma_sparse_sm80.h"
 /////////////////////////////////////////////////////////////////////////////////////////////////

third_party/cutlass/include/cutlass/arch/mma_sm50.h (78 changes: 40 additions & 38 deletions)

@@ -62,6 +62,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, L
 
   using Shape = gemm::GemmShape<1, 1, 1>;
   using Operator = OpMultiplyAdd;
+  using ElementC = float;
 
   CUTLASS_HOST_DEVICE
   void operator()(

@@ -144,12 +145,12 @@ template <
 struct Mma<
   gemm::GemmShape<1, 1, 1>,
   1,
-  complex<float>,
-  LayoutA,
-  complex<float>,
-  LayoutB,
-  complex<float>,
-  LayoutC,
+  complex<float>,
+  LayoutA,
+  complex<float>,
+  LayoutB,
+  complex<float>,
+  LayoutC,
   OpMultiplyAdd> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;

@@ -184,12 +185,12 @@ template <
 struct Mma<
   gemm::GemmShape<1, 1, 1>,
   1,
-  complex<float>,
-  LayoutA,
-  float,
-  LayoutB,
-  complex<float>,
-  LayoutC,
+  complex<float>,
+  LayoutA,
+  float,
+  LayoutB,
+  complex<float>,
+  LayoutC,
   OpMultiplyAdd> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;

@@ -222,12 +223,12 @@ template <
 struct Mma<
   gemm::GemmShape<1, 1, 1>,
   1,
-  float,
-  LayoutA,
-  complex<float>,
-  LayoutB,
-  complex<float>,
-  LayoutC,
+  float,
+  LayoutA,
+  complex<float>,
+  LayoutB,
+  complex<float>,
+  LayoutC,
   OpMultiplyAdd> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;

@@ -260,12 +261,12 @@ template <
 struct Mma<
   gemm::GemmShape<1, 1, 1>,
   1,
-  complex<double>,
-  LayoutA,
-  complex<double>,
-  LayoutB,
-  complex<double>,
-  LayoutC,
+  complex<double>,
+  LayoutA,
+  complex<double>,
+  LayoutB,
+  complex<double>,
+  LayoutC,
   OpMultiplyAdd> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;

@@ -298,6 +299,12 @@ template <
 struct Mma<
   gemm::GemmShape<1, 1, 1>,
   1,
-  complex<double>,
-  LayoutA,
-  double,
-  LayoutB,
-  complex<double>,
-  LayoutC,
+  complex<double>,
+  LayoutA,
+  double,
+  LayoutB,
+  complex<double>,
+  LayoutC,
   OpMultiplyAdd> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;

@@ -334,12 +335,12 @@ template <
 struct Mma<
   gemm::GemmShape<1, 1, 1>,
   1,
-  double,
-  LayoutA,
-  complex<double>,
-  LayoutB,
-  complex<double>,
-  LayoutC,
+  double,
+  LayoutA,
+  complex<double>,
+  LayoutB,
+  complex<double>,
+  LayoutC,
   OpMultiplyAdd> {
 
   using Shape = gemm::GemmShape<1, 1, 1>;

@@ -373,7 +374,8 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,
 
   using Shape = gemm::GemmShape<1, 1, 1>;
   using Operator = OpMultiplyAdd;
-
+  using ElementC = float;
+
   CUTLASS_HOST_DEVICE
   void operator()(
     Array<float, 1> &d,

@@ -412,7 +414,7 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<f
     multiply_add<Element, Element, Element> op;
     d[0] = op(a[0], b[0], c[0]);
   }
-
+
 };
 
 }

third_party/cutlass (file name not shown in this view)

@@ -1205,7 +1205,7 @@ class MmaVoltaTensorOpAccumulatorTileIterator {
         InterleavedTile::kColumn / InstructionShape::kN>;
   };
 
-private:
+public:
 
   // Assume accumulator tile is multipile interleaved 32x32 tile.
   static int const kElementsPerPartial = 4;

xformers/benchmarks/benchmark_mem_eff_attention.py (4 changes: 4 additions & 0 deletions)

@@ -15,6 +15,8 @@
 
 import xformers.ops
 
+torch.backends.cuda.matmul.allow_tf32 = False
+
 
 def create_attn_bias(
     bias_type, batch_size: int, q_len: int, kv_len: int, device, dtype

@@ -89,6 +91,7 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
         k=K,
         attn_bias_type=attn_bias_type,
         has_dropout=False,
+        kv_len=M,
     )
     try:
         op = dispatch.op if FORCE_OP is None else FORCE_OP

@@ -162,6 +165,7 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
         k=K,
         attn_bias_type=attn_bias_type,
         has_dropout=False,
+        kv_len=M,
     )
     try:
         op = dispatch.op if FORCE_OP is None else FORCE_OP

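The benchmark now pins torch.backends.cuda.matmul.allow_tf32 = False before building its reference results. On Ampere-class GPUs, and depending on the PyTorch version, float32 matmuls can otherwise run in TF32, which silently changes the numerics of the baseline being compared against. A small self-contained illustration of the effect (requires a CUDA GPU; on pre-Ampere hardware both errors come out the same):

import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# float32 matmul with TF32 enabled (reduced mantissa precision on Ampere and newer)
torch.backends.cuda.matmul.allow_tf32 = True
out_tf32 = a @ b

# float32 matmul with TF32 disabled, as the benchmark now forces
torch.backends.cuda.matmul.allow_tf32 = False
out_fp32 = a @ b

# Compare both against a float64 reference
ref = (a.double() @ b.double()).float()
print("max err, tf32:", (out_tf32 - ref).abs().max().item())
print("max err, fp32:", (out_fp32 - ref).abs().max().item())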