Custom attention bias #617

Open

wants to merge 26 commits into base: main

Changes from 22 of 26 commits

Commits
739f4bc  Add arbitrary bias matrix (b-albar, Oct 19, 2023)
f5d1aa1  Fix number of returns in bwd (b-albar, Oct 30, 2023)
f1c1689  Fix for sequence length not multiple of 8 (b-albar, Oct 30, 2023)
57df738  Fix backward for non power of 2 sequences (b-albar, Nov 2, 2023)
eb27a75  Add gradients for the bias matrix (b-albar, Nov 3, 2023)
901a0da  FIx compilation error (b-albar, Nov 8, 2023)
801dfb7  Fix indexing mistake (b-albar, Nov 13, 2023)
80815ed  Replace torch.zeros with empty_like (b-albar, Nov 13, 2023)
90b5cd1  Add causal to test_attn_bias (b-albar, Nov 13, 2023)
993db98  Merge branch 'main' into trainable-bias (b-albar, Dec 7, 2023)
1d3820c  Fix interface when attn_bias is None (b-albar, Dec 8, 2023)
075fe8f  Fix previous merge (b-albar, Dec 8, 2023)
7e0402d  Remove syncthreads (b-albar, Dec 11, 2023)
4519f35  Optimize apply_attn_bias for some sequences lengths (b-albar, Dec 13, 2023)
2b0885b  Optim for 64 head size (b-albar, Dec 15, 2023)
9360ed8  Add boolean for attn_bias gradient (b-albar, Jan 22, 2024)
c5aea6b  Merge branch 'main' into trainable-bias (b-albar, Jan 24, 2024)
b4d9797  Merge branch 'main' into trainable-bias (b-albar, Jan 26, 2024)
b08217f  Merge branch 'main' into trainable-bias (b-albar, Feb 4, 2024)
83481bd  Fix merge mistakes (b-albar, Feb 8, 2024)
04da83f  Refactor attn_bias code (b-albar, Feb 8, 2024)
e971476  Various updates (b-albar, Feb 12, 2024)
d5554d4  Fix typo (b-albar, Feb 20, 2024)
7fe76b6  Disable alibi as a temporary workaround for memory error (b-albar, Feb 20, 2024)
f9ba2c2  Temporarily disable uneven_k and local (b-albar, Feb 21, 2024)
5819b36  Fix race condition (b-albar, Mar 15, 2024)
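
The diff below threads the optional additive attention bias through the C++ dispatch layer as a raw pointer plus three strides (batch, head, and query). The following sketch, which is an illustration rather than code from this PR, shows what those strides buy: the kernel can address any broadcastable bias layout with plain pointer arithmetic, and a stride of 0 in the batch or head dimension makes every batch or head read the same slice, which is how a (1, 1, seqlen_q, seqlen_k) bias is supported without materializing copies. The helper name is hypothetical, and the key dimension is assumed contiguous, matching the stride(-1) == 1 checks in flash_api.cpp.

#include <cstddef>
#include <cstdint>

// Hypothetical helper, illustration only: address bias[b][h][q][k] from the
// three stride values passed through set_params_fprop / set_params_dgrad.
// A batch or head stride of 0 broadcasts that dimension.
template <typename T>
const T *attn_bias_element(const T *attn_bias_ptr,
                           uint32_t attn_bias_batch_stride,
                           uint32_t attn_bias_head_stride,
                           uint32_t attn_bias_q_stride,
                           int b, int h, int q, int k) {
    return attn_bias_ptr
         + static_cast<std::size_t>(b) * attn_bias_batch_stride
         + static_cast<std::size_t>(h) * attn_bias_head_stride
         + static_cast<std::size_t>(q) * attn_bias_q_stride
         + static_cast<std::size_t>(k);
}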
163 changes: 153 additions & 10 deletions csrc/flash_attn/flash_api.cpp
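
Throughout the change, the bias is handled in row blocks of 8: mha_fwd pads a misaligned bias internally before taking its strides, while mha_bwd expects attn_bias (and the optional ds gradient buffer) to arrive with seqlen_q and seqlen_k already rounded up to multiples of 8. Below is a minimal libtorch sketch of that rounding, under the assumption that it matches the padding done in the forward path; the helper name is hypothetical and not part of this PR.

#include <torch/torch.h>

// Hypothetical helper, illustration only: pad the last two dimensions of an
// attention bias up to multiples of 8. The trailing % 8 keeps an already
// aligned dimension untouched.
at::Tensor pad_attn_bias_to_multiple_of_8(const at::Tensor &bias) {
    const int64_t seqlen_q = bias.size(-2);
    const int64_t seqlen_k = bias.size(-1);
    const int64_t pad_q = (8 - seqlen_q % 8) % 8;
    const int64_t pad_k = (8 - seqlen_k % 8) % 8;
    if (pad_q == 0 && pad_k == 0) { return bias; }
    namespace F = torch::nn::functional;
    // PadFuncOptions lists (left, right) amounts starting from the last dimension.
    return F::pad(bias, F::PadFuncOptions({0, pad_k, 0, pad_q}));
}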
@@ -43,6 +43,10 @@ void set_params_fprop(Flash_fwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
void* attn_bias,
uint32_t attn_bias_batch_stride,
uint32_t attn_bias_head_stride,
uint32_t attn_bias_q_stride,
bool seqlenq_ngroups_swapped=false) {

// Reset the parameters
@@ -65,6 +69,12 @@
params.o_row_stride = out.stride(-3);
params.o_head_stride = out.stride(-2);

// Attention biases
params.attn_bias_ptr = attn_bias;
params.attn_bias_batch_stride = attn_bias_batch_stride;
params.attn_bias_head_stride = attn_bias_head_stride;
params.attn_bias_q_stride = attn_bias_q_stride;

if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = q.stride(0);
params.k_batch_stride = k.stride(0);
@@ -168,7 +178,12 @@ void set_params_dgrad(Flash_bwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
bool deterministic) {
bool deterministic,
void* attn_bias,
void* attn_ds,
uint32_t attn_bias_batch_stride,
uint32_t attn_bias_head_stride,
uint32_t attn_bias_q_stride) {

set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -181,7 +196,11 @@
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
attn_bias,
attn_bias_batch_stride,
attn_bias_head_stride,
attn_bias_q_stride);

// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
@@ -197,6 +216,9 @@
params.dk_head_stride = dk.stride(-2);
params.dv_head_stride = dv.stride(-2);

// Attention biases
params.attn_ds_ptr = attn_ds;

if (cu_seqlens_q_d == nullptr) {
params.do_batch_stride = dout.stride(0);
params.dq_batch_stride = dq.stride(0);
@@ -324,6 +346,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
bool is_causal,
int window_size_left,
int window_size_right,
const c10::optional<at::Tensor> &attn_bias, // batch_size x num_heads x seqlen_q x seqlen_k
const bool return_softmax,
c10::optional<at::Generator> gen_) {

@@ -383,6 +406,47 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

// Attention biases
uint32_t attn_bias_batch_stride = 0;
uint32_t attn_bias_head_stride = 0;
uint32_t attn_bias_q_stride = 0;
at::Tensor attn_bias_padded;

#ifdef FLASHATTENTION_DISABLE_ATTN_BIAS
TORCH_CHECK(!attn_bias.has_value(),
"This flash attention build does not support custom attention biases.");
#endif

if (attn_bias.has_value()) {
TORCH_CHECK(attn_bias.value().is_cuda(), "attn_bias must be on CUDA device");
TORCH_CHECK(attn_bias.value().stride(-1) == 1, "attn_bias must have contiguous last dimension");
TORCH_CHECK(attn_bias.value().dtype() == q_dtype, "attention bias and query must have the same dtype");

auto bias_sizes = attn_bias.value().sizes();
TORCH_CHECK((bias_sizes[0] == batch_size) || (bias_sizes[0] == 1), "First dimension of the bias should be 1 or batch size");
TORCH_CHECK((bias_sizes[1] == num_heads) || (bias_sizes[1] == 1), "Second dimension of the bias should be 1 or num_heads");
TORCH_CHECK((bias_sizes[2] == seqlen_q) && (bias_sizes[3] == seqlen_k), "Last dimensions of bias should be seqlen_q and seqlen_k");
//CHECK_SHAPE(attn_bias.value(), batch_size, num_heads, seqlen_q, seqlen_k);

if ((seqlen_q % 8 != 0) || (seqlen_k % 8 != 0)) {
attn_bias_padded = torch::nn::functional::pad(attn_bias.value(), torch::nn::functional::PadFuncOptions({0, (8 - seqlen_k % 8) % 8, 0, (8 - seqlen_q % 8) % 8}));
} else {
attn_bias_padded = attn_bias.value();
}

attn_bias_batch_stride = attn_bias_padded.stride(0);
attn_bias_head_stride = attn_bias_padded.stride(1);
attn_bias_q_stride = attn_bias_padded.stride(2);

// Trick to support bias shape like (1, 1, seqlen_q, seqlen_k)
if ((bias_sizes[0] == 1) && (batch_size != 1)) {
attn_bias_batch_stride = 0;
}
if ((bias_sizes[1] == 1) && (num_heads != 1)) {
attn_bias_head_stride = 0;
}
}

at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
@@ -442,7 +506,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
attn_bias ? attn_bias_padded.data_ptr() : nullptr,
attn_bias_batch_stride,
attn_bias_head_stride,
attn_bias_q_stride);


set_params_splitkv(params, batch_size, num_heads,
@@ -489,7 +557,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
return {out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, p, rng_state};
}

std::vector<at::Tensor>
@@ -653,7 +721,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
softmax_scale,
window_size_left,
window_size_right,
nullptr,
0,
0,
0,
seqlenq_ngroups_swapped);

if (seqlenq_ngroups_swapped) {
// Only apply split-k for decoding
set_params_splitkv(params, batch_size, num_heads,
@@ -732,6 +805,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
int window_size_left,
int window_size_right,
const bool deterministic,
const c10::optional<at::Tensor> &attn_bias,
const bool attn_bias_require_grad,
c10::optional<at::Tensor> &ds_, // batch_size x num_heads x seqlen_q x seqlen_k
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {

@@ -804,6 +880,57 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

// Attention biases
uint32_t attn_bias_batch_stride = 0;
uint32_t attn_bias_head_stride = 0;
uint32_t attn_bias_q_stride = 0;
at::Tensor ds;

if (attn_bias.has_value()) {
TORCH_CHECK(attn_bias.value().is_cuda(), "attn_bias must be on CUDA device");
TORCH_CHECK(attn_bias.value().stride(-1) == 1, "attn_bias must have contiguous last dimension");
TORCH_CHECK(attn_bias.value().dtype() == q_dtype, "attention bias and query must have the same dtype");

const int seqlen_q_round8 = round_multiple(seqlen_q, 8);
const int seqlen_k_round8 = round_multiple(seqlen_k, 8);

auto bias_sizes = attn_bias.value().sizes();
TORCH_CHECK((bias_sizes[0] == batch_size) || (bias_sizes[0] == 1), "First dimension of the bias should be 1 or batch size");
TORCH_CHECK((bias_sizes[1] == num_heads) || (bias_sizes[1] == 1), "Second dimension of the bias should be 1 or num_heads");
TORCH_CHECK((bias_sizes[2] == seqlen_q_round8) && (bias_sizes[3] == seqlen_k_round8), "Last dimensions of bias should be seqlen_q and seqlen_k rounded up to multiples of 8");
//CHECK_SHAPE(attn_bias.value(), batch_size, num_heads, seqlen_q_round8, seqlen_k_round8);

attn_bias_batch_stride = attn_bias.value().stride(0);
attn_bias_head_stride = attn_bias.value().stride(1);
attn_bias_q_stride = attn_bias.value().stride(2);

// Trick to support bias shape like (1, 1, seqlen_q, seqlen_k)
if ((bias_sizes[0] == 1) && (batch_size != 1)) {
attn_bias_batch_stride = 0;
}
if ((bias_sizes[1] == 1) && (num_heads != 1)) {
attn_bias_head_stride = 0;
}

if (attn_bias_require_grad) {
if (ds_.has_value()) {
ds = ds_.value();
TORCH_CHECK(ds.dtype() == q_dtype, "ds must have the same dtype as q");
CHECK_DEVICE(ds);
TORCH_CHECK(ds.stride(-1) == 1, "ds must have contiguous last dimension");
CHECK_SHAPE(ds, bias_sizes[0], bias_sizes[1], seqlen_q_round8, seqlen_k_round8);

TORCH_CHECK(ds.is_contiguous());
} else {
ds = torch::empty_like(attn_bias.value());
}

if (is_causal || ((attn_bias_batch_stride == 0) || (attn_bias_head_stride == 0))) {
ds.zero_();
}
}
}

at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
@@ -895,8 +1022,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_scale,
window_size_left,
window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
deterministic,
attn_bias ? attn_bias->data_ptr() : nullptr,
attn_bias && attn_bias_require_grad ? ds.data_ptr() : nullptr,
attn_bias_batch_stride,
attn_bias_head_stride,
attn_bias_q_stride);

auto launch = &run_mha_bwd;

@@ -939,7 +1070,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}

return { dq, dk, dv, softmax_d };
if (attn_bias && attn_bias_require_grad && ((seqlen_q % 8 != 0) || (seqlen_k % 8 != 0))) {
ds = ds.index({"...", torch::indexing::Slice(torch::indexing::None, seqlen_q), torch::indexing::Slice(torch::indexing::None, seqlen_k)});
}

return { dq, dk, dv, ds, softmax_d };
}

std::vector<at::Tensor>
@@ -1144,8 +1279,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_scale,
window_size_left,
window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
deterministic,
nullptr,
nullptr,
0,
0,
0);

auto launch = &run_mha_bwd;

@@ -1343,7 +1482,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
nullptr,
0,
0,
0);

at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
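
For reference, and up to dropout, causal or local masking, and the split-kv decoding path, the fused kernels extended here compute the same result as the eager composition below. This is a hedged sketch in the libtorch C++ API rather than code from the repository; the function name and the (batch, seqlen, num_heads, head_dim) layout are assumptions made for illustration, and the bias broadcasts over any leading dimension of size 1, mirroring the zero-stride trick described above.

#include <torch/torch.h>
#include <cmath>

// Unfused reference for attention with an additive bias:
//   out = softmax(Q K^T * scale + bias) V
// q: (B, Sq, H, D), k/v: (B, Sk, H, D), bias: (B or 1, H or 1, Sq, Sk).
at::Tensor attention_bias_reference(const at::Tensor &q,
                                    const at::Tensor &k,
                                    const at::Tensor &v,
                                    const at::Tensor &bias) {
    const double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
    auto qh = q.transpose(1, 2);                         // (B, H, Sq, D)
    auto kh = k.transpose(1, 2);                         // (B, H, Sk, D)
    auto vh = v.transpose(1, 2);                         // (B, H, Sk, D)
    auto scores = torch::matmul(qh, kh.transpose(-2, -1)) * scale + bias;
    auto probs = torch::softmax(scores, /*dim=*/-1);     // (B, H, Sq, Sk)
    auto out = torch::matmul(probs, vh);                 // (B, H, Sq, D)
    return out.transpose(1, 2).contiguous();             // back to (B, Sq, H, D)
}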