fctls_bflsh: New OP that combines cutlass's fw + flash's bw
ghstack-source-id: 8c2c5e6c4aa8c7cdf9bc972b5434e6c0fed58c6f
Pull Request resolved: #469
danthe3rd committed Oct 7, 2022
1 parent a09e29d commit df167a1
Showing 2 changed files with 88 additions and 19 deletions.
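For context: the operator added here, MemoryEfficientAttentionCutlassFwdFlashBwOp (NAME = "fctls_bflsh"), runs the CUTLASS kernel for the forward pass and re-dispatches the backward pass to the FlashAttention kernel. Below is a minimal usage sketch, not part of this commit; it assumes the top-level xformers.ops.memory_efficient_attention helper in this version accepts an explicit op= override and 4D BMHK inputs, and the shapes are purely illustrative.

# Hypothetical usage sketch (illustration only, not from this diff).
import torch
import xformers.ops as xops

B, M, H, K = 2, 1024, 8, 128  # batch, sequence length, heads, head dim (illustrative)
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.half, requires_grad=True)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.half, requires_grad=True)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.half, requires_grad=True)

# Forward runs the CUTLASS kernel; backward replays through FlashAttention.
out = xops.memory_efficient_attention(
    q, k, v, op=xops.MemoryEfficientAttentionCutlassFwdFlashBwOp
)
out.sum().backward()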
11 changes: 9 additions & 2 deletions tests/test_mem_eff_attention.py
@@ -53,7 +53,10 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
# Crashes on Flash, Errors on Cutlass
# shapes.append((1, 1, 64000, 300, 128, 128))
# Add some random shapes
if op is xformers.ops.MemoryEfficientAttentionCutlassOp:
if op in [
xformers.ops.MemoryEfficientAttentionCutlassOp,
xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
]:
K_CHOICES = [8 * i for i in range(1, 256 // 8)]
r = random.Random(0)
for _ in range(20):
@@ -75,6 +78,7 @@ def _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(**kwargs):
xformers.ops.MemoryEfficientAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassOp,
xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
]:
for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op, **kwargs):
for device in _devices:
@@ -415,7 +419,10 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
k,
kv,
) = op_device_dtype_B_Mq_Mkv_H_K_Kv
if op.FORWARD_OPERATOR is None:
if (
op.FORWARD_OPERATOR is None
or op is xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp
):
return
query, key, value, attn_bias = create_tensors(
*op_device_dtype_B_Mq_Mkv_H_K_Kv, fmt="BMK"
96 changes: 79 additions & 17 deletions xformers/ops.py
@@ -6,6 +6,7 @@

import math
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union

import torch
@@ -427,16 +428,17 @@ def forward_no_grad(
)

@classmethod
def forward(cls, ctx, query, key, value, attn_bias, p):
causal = isinstance(attn_bias, LowerTriangularMask)
return_softmax = False

def prepare_inputs(
cls, ctx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
):
batch = query.shape[0]
seqlen_q = query.shape[1]
seqlen_k = key.shape[1]
num_heads = query.shape[2]
head_dim_q = query.shape[3]
head_dim_v = value.shape[3]
ctx.max_seqlen_q = seqlen_q
ctx.max_seqlen_k = seqlen_k

cu_seqlens_k = torch.arange(
0,
@@ -458,12 +460,22 @@ def forward(cls, ctx, query, key, value, attn_bias, p):

# Initially we have `query.shape = [batch, seqlen, head_dim_q]`
# We want format `[batch * seqlen, num_heads, head_dim_q]`
query_api_input_shape = query.shape
key_api_input_shape = key.shape
value_api_input_shape = value.shape
ctx.query_api_input_shape = query.shape
ctx.key_api_input_shape = key.shape
ctx.value_api_input_shape = value.shape
query = query.reshape([batch * seqlen_q, num_heads, head_dim_q])
key = key.reshape([batch * seqlen_k, num_heads, head_dim_q])
value = value.reshape([batch * seqlen_k, num_heads, head_dim_v])
return query, key, value, cu_seqlens_k, cu_seqlens_q

@classmethod
def forward(cls, ctx, query, key, value, attn_bias, p):
causal = isinstance(attn_bias, LowerTriangularMask)
return_softmax = False
ctx_flash = ctx if ctx is not None else SimpleNamespace()
query, key, value, cu_seqlens_k, cu_seqlens_q = cls.prepare_inputs(
ctx_flash, query, key, value
)

# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if p > 0 else None
@@ -474,8 +486,8 @@ def forward(cls, ctx, query, key, value, attn_bias, p):
value,
cu_seqlens_q,
cu_seqlens_k,
seqlen_q,
seqlen_k,
ctx_flash.max_seqlen_q,
ctx_flash.max_seqlen_k,
p,
softmax_scale,
causal=causal,
@@ -493,18 +505,17 @@ def forward(cls, ctx, query, key, value, attn_bias, p):
rng_state,
)
ctx.dropout_p = p
ctx.max_seqlen_q = seqlen_q
ctx.max_seqlen_k = seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.kernel_output_shape = out.shape
ctx.query_api_input_shape = query_api_input_shape
ctx.key_api_input_shape = key_api_input_shape
ctx.value_api_input_shape = value_api_input_shape
return out

@classmethod
def backward(cls, ctx, grad):
return cls._backward(ctx, grad, ctx.saved_tensors)

@classmethod
def _backward(cls, ctx, grad, saved_tensors):
(
q,
k,
@@ -514,7 +525,7 @@ def backward(cls, ctx, grad):
cu_seqlens_q,
cu_seqlens_k,
rng_state,
) = ctx.saved_tensors
) = saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
@@ -641,6 +652,38 @@ def _flash_attn_backward(
return dq, dk, dv, softmax_d


class MemoryEfficientAttentionCutlassFwdFlashBwOp(MemoryEfficientAttentionCutlassOp):
FW_OP = MemoryEfficientAttentionCutlassOp
BW_OP = MemoryEfficientAttentionFlashAttentionOp
NAME = "fctls_bflsh"

@classmethod
def supports(cls, d: "AttentionOpDispatch") -> bool:
return cls.FW_OP.supports(d) and cls.BW_OP.supports(d)

@classmethod
def backward(cls, ctx, grad):
query, key, value, lse, out = ctx.saved_tensors
ctx_flash = SimpleNamespace()

ctx_flash.causal = ctx.causal
ctx_flash.dropout_p = 0.0
query, key, value, cu_seqlens_k, cu_seqlens_q = cls.BW_OP.prepare_inputs(
ctx_flash, query, key, value
)
ctx_flash.kernel_output_shape = (query.shape[0], query.shape[1], value.shape[2])
ctx_flash.softmax_scale = query.shape[-1] ** (-0.5)
rng_state = None

out = out.reshape(ctx_flash.kernel_output_shape)
grad = grad.reshape(ctx_flash.kernel_output_shape)
return cls.BW_OP._backward(
ctx_flash,
grad,
[query, key, value, out, lse, cu_seqlens_q, cu_seqlens_k, rng_state],
)


@dataclass
class AttentionOpDispatch:
dtype: torch.dtype
@@ -651,18 +694,32 @@ class AttentionOpDispatch:
kv_len: int
q_len: int
kv: int = -1
batch_size: int = -1
num_heads: int = 1

def __post_init__(self):
if self.kv == -1:
self.kv = self.k

def _is_cutlass_fwd_faster_than_flash(self) -> bool:
# Very small batch sizes - if batch size specified
if self.batch_size > 0:
threads_flash = self.batch_size * self.num_heads
threads_cutlass = threads_flash * (self.q_len // 64)
if threads_flash < 60 and (threads_cutlass // 2) >= threads_flash:
return True
# Large values of K
return max(self.k, self.kv) == 128

@property
def op(self) -> Type[AttentionOpBase]:
priority_list_ops: List[Type[AttentionOpBase]] = [
MemoryEfficientAttentionFlashAttentionOp,
MemoryEfficientAttentionCutlassOp,
MemoryEfficientAttentionOp,
]
if self._is_cutlass_fwd_faster_than_flash():
priority_list_ops.insert(0, MemoryEfficientAttentionCutlassFwdFlashBwOp)
for op in priority_list_ops:
if op.supports(self):
return op
@@ -677,15 +734,20 @@ def from_arguments(
attn_bias: Optional[Union[torch.Tensor, AttentionMask]] = None,
p: float = 0.0,
) -> "AttentionOpDispatch":
B, H = query.shape[0], 1
if query.ndim == 4:
H = query.shape[2]
return AttentionOpDispatch(
dtype=query.dtype,
device=query.device,
k=query.shape[-1],
kv=value.shape[-1],
has_dropout=p > 0.0,
attn_bias_type=type(attn_bias),
kv_len=value.shape[-2],
q_len=query.shape[-2],
kv_len=value.shape[1],
q_len=query.shape[1],
batch_size=B,
num_heads=H,
)


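A note on the dispatch change above: _is_cutlass_fwd_faster_than_flash promotes the combined op to the front of the priority list either when the largest head dimension (k or kv) is 128, or when batch_size * num_heads is too small to keep the FlashAttention kernel busy relative to the CUTLASS forward, which (per the q_len // 64 factor) additionally parallelizes over query blocks of 64. A standalone restatement of that arithmetic with worked numbers, for illustration only:

# Restatement of the heuristic above, with worked examples (illustration only).
def cutlass_fwd_faster_than_flash(batch_size, num_heads, q_len, k, kv):
    # Very small batch sizes - if batch size specified
    if batch_size > 0:
        threads_flash = batch_size * num_heads
        threads_cutlass = threads_flash * (q_len // 64)
        if threads_flash < 60 and (threads_cutlass // 2) >= threads_flash:
            return True
    # Large values of K
    return max(k, kv) == 128

# batch=2, heads=8, q_len=1024: flash gets 16 units of parallelism (< 60), while
# cutlass gets 16 * (1024 // 64) = 256, and 256 // 2 >= 16, so the combined op
# is promoted to the front of the priority list.
assert cutlass_fwd_faster_than_flash(2, 8, 1024, 64, 64)
# batch=64, heads=8: 512 >= 60 and K != 128, so the default order stands.
assert not cutlass_fwd_faster_than_flash(64, 8, 1024, 64, 64)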