Hunt correctness for extend attention + RPE
Avoid flex_attention for RPE, as it is unclear whether a correct implementation
is possible given the limitations of create_block_mask regarding conditionals.

Instead, we use the manual torch implementation that is known to be correct
(sketched below, after the repro and error).

As we update the test and extend_attention_rpe to use a static max_rpe_context_length,
a new error appears that suggests an issue with the indexing in extend_attention_rpe.

Repro:
```
pytest tests/kernel/wave/attention/extend_attention_test.py --run-e2e -v -k "rpe"
```

Errors out with:
```
E               Diagnostics:
E               <stdin>:282:18: error: 'vector.gather' op operand #2 must be vector of integer or index values, but got 'index'
E                         %468 = "vector.gather"(%109, %39, %39, %467, %44) : (memref<?xf32, strided<[1], offset: ?>>, index, index, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
E                                ^
```
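
For reference, the shape of the manual torch path the test now relies on is sketched
below: a T5-style RPE bias, masked to relative distances in [0, max_rpe_context_length),
is added to the scaled QK scores before softmax. This is a minimal sketch only; the
helper name reference_rpe_attention is hypothetical, it assumes an equal number of
query and key/value heads (the test additionally expands K and V for GQA), and the
test itself splits the logic across t5_rpe_masked_cond and context_attention_fwd.

```python
import math
import torch
from torch.nn import functional as F


def reference_rpe_attention(q, k, v, rpe, max_rpe_context_length):
    # q, k, v: [num_heads, seq_len, head_dim]; rpe: [max_rpe_context_length + 1].
    seq_len = q.shape[1]
    positions = torch.arange(seq_len, device=q.device)
    pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0)
    mask = (pos_diff >= 0) & (pos_diff < max_rpe_context_length)
    rpe_cond = torch.zeros(seq_len, seq_len, dtype=rpe.dtype, device=rpe.device)
    rpe_cond[mask] = rpe[pos_diff[mask]]
    # Scaled dot-product scores plus the masked relative-position bias.
    a = torch.bmm(q, k.transpose(-1, -2)) * math.sqrt(1.0 / q.shape[-1])
    a = a + rpe_cond.unsqueeze(0)
    return torch.bmm(F.softmax(a, dim=-1).to(dtype=v.dtype), v)
```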

Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
nicolasvasilache committed Feb 14, 2025
1 parent d1168c9 commit 8726a56
Showing 2 changed files with 62 additions and 55 deletions.
12 changes: 5 additions & 7 deletions iree/turbine/kernel/wave/templates/extend_attention_rpe.py
@@ -36,6 +36,7 @@ def get_extend_attention_rpe_kernel(
is_causal: Optional[bool] = False,
layer_scaling: Optional[float] = None,
num_waves: Optional[int] = 4,
max_rpe_context_length: Optional[int] = 0,
):
# Determine dtype of operands.
wave_input_dtype = torch_dtype_to_wave(input_dtype)
@@ -154,7 +155,8 @@ def get_extend_attention_rpe_kernel(
)

clip = sympy.Piecewise(
(d0 - d1, (d0 - d1 < MAX_EXTEND_SEQ_LEN) & (d0 - d1 > 0)), (0, True)
(d0 - d1, (d0 - d1 < max_rpe_context_length) & (d0 - d1 >= 0)),
(max_rpe_context_length, True),
)
rpe_mapping = tkw.IndexMapping(
num_iterators=3,
@@ -271,10 +273,8 @@ def first_loop(
mapping_dynamic_vals=(i, j),
elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
)
x_j = x_j + rpe_reg

# Layer scaling since we use log2 instead of log2
x_j = x_j * layer_scale_reg
x_j = x_j * layer_scale_reg + rpe_reg

n_kv_index = tkw.self_index(N_KV, tkl.i32)
mask = tkw.apply_expr(n_kv_index, lambda x: x < N_KV)
@@ -338,10 +338,8 @@ def second_loop(
mapping_dynamic_vals=(i, j),
elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
)
x_j = x_j + rpe_reg

# Layer scaling since we use log2 instead of log2
x_j = x_j * layer_scale_reg
x_j = x_j * layer_scale_reg + rpe_reg

n_kv_index = tkw.self_index(N_KV, tkl.i32)
mask = tkw.apply_expr(n_kv_index, lambda x: x < N_KV)
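
To make the index clamping encoded by the sympy.Piecewise above concrete, here is a
scalar sketch in plain Python (the helper name clip_rpe_index is hypothetical):
relative distances d0 - d1 inside [0, max_rpe_context_length) index the RPE table
directly, and every out-of-range distance is redirected to the sentinel slot
max_rpe_context_length, which the test below zeroes out.

```python
def clip_rpe_index(d0: int, d1: int, max_rpe_context_length: int) -> int:
    # Mirrors Piecewise((d0 - d1, (d0 - d1 < max_rpe_context_length) & (d0 - d1 >= 0)),
    #                   (max_rpe_context_length, True)).
    diff = d0 - d1
    if 0 <= diff < max_rpe_context_length:
        return diff  # in-range relative position: index a real RPE entry
    return max_rpe_context_length  # out of range: sentinel entry, zeroed in the test
```
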
105 changes: 57 additions & 48 deletions tests/kernel/wave/attention/extend_attention_test.py
@@ -4,9 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import math
import pytest
import torch
import math
from torch.nn import functional as F


import iree.turbine.kernel as tk
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import (
@@ -46,6 +49,17 @@
# Reference paged attention implementation from vLLM and sglang.


def t5_rpe_masked_cond(
rpe: torch.Tensor, max_rpe_context_length: int, sequence_length: int
) -> torch.Tensor:
positions = torch.arange(sequence_length).to(device=rpe.device)
pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0)
mask = ((pos_diff >= 0) & (pos_diff < max_rpe_context_length)).to(device=rpe.device)
rpe_cond = device_zeros(sequence_length, sequence_length, dtype=rpe.dtype)
rpe_cond[mask] = rpe[pos_diff[mask]]
return rpe_cond


class ScoreMod(Enum):
SoftCap = 0
RPE = 1
@@ -64,30 +78,6 @@ def context_attention_fwd(
rpe_bias: torch.Tensor = None,
score_mod: ScoreMod = ScoreMod.SoftCap,
):
def soft_cap(score, b, h, q_idx, kv_idx):
score = score / logit_cap
score = torch.tanh(score)
score = score * logit_cap
return score

zero_tensor = torch.zeros_like(rpe_bias)

def t5_rpe(score, b, h, q_idx, kv_idx):
bias = torch.where(q_idx - kv_idx >= 0, score, zero_tensor)
bias = torch.where(q_idx - kv_idx < max_len_extend, score, zero_tensor)
score = score + bias[q_idx - kv_idx]
return score

match score_mod:
case ScoreMod.SoftCap:
score_mod_fn = soft_cap
case ScoreMod.RPE:
score_mod_fn = t5_rpe
case _:
raise ValueError("Unexpectred score_mod type")

def causal(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

cu_seq_lens = [0] * (len(b_seq_len) + 1)
for i, seq_len in enumerate(b_seq_len):
@@ -96,24 +86,31 @@ def causal(b, h, q_idx, kv_idx):
for i in range(len(b_seq_len)):
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
qkv_len = end - start
block_mask = (
create_block_mask(causal, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len)
if is_causal
else None
)
o_torch = (
flex_attention(
q[start:end].permute(1, 0, 2).unsqueeze(0),
k[start:end].permute(1, 0, 2).unsqueeze(0),
v[start:end].permute(1, 0, 2).unsqueeze(0),
score_mod=score_mod_fn,
enable_gqa=True,
block_mask=block_mask,
Q = q[start:end].permute(1, 0, 2)
K = k[start:end].permute(1, 0, 2)
K = K.expand(Q.shape[0], *K.shape[1:])
V = v[start:end].permute(1, 0, 2)
V = V.expand(Q.shape[0], *V.shape[1:])
dk_sqrt = math.sqrt(1.0 / Q.shape[-1])
a = torch.bmm(Q, K.transpose(-1, -2)) * dk_sqrt
if score_mod == ScoreMod.SoftCap:
a = a / logit_cap
a = torch.tanh(a)
a = a * logit_cap
else:
rpe_cond = t5_rpe_masked_cond(
rpe_bias,
max_rpe_context_length=max_rpe_context_length,
sequence_length=K.shape[1],
)
.squeeze(0)
.permute(1, 0, 2)
)
o[start:end] = o_torch
print(a.shape)
print(rpe_cond.shape)
rpe_cond = rpe_cond.unsqueeze(0)
rpe_cond = rpe_cond.expand(Q.shape[0], *rpe_cond.shape[1:])
a = a + rpe_cond
reference = torch.bmm(F.softmax(a, dim=-1).to(dtype=V.dtype), V)
reference = reference.squeeze(0).permute(1, 0, 2)
o[start:end] = reference

return o
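
As a sanity check of t5_rpe_masked_cond above, here is a tiny worked example
(plain torch.zeros stands in for device_zeros so the snippet is self-contained;
the bias values are made up): with max_rpe_context_length = 2 and
sequence_length = 4, only distances 0 and 1 pick up bias values and everything
else stays zero.

```python
import torch

rpe = torch.tensor([0.5, 0.25, 0.0])  # max_rpe_context_length + 1 entries, last one zero
positions = torch.arange(4)
pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0)
mask = (pos_diff >= 0) & (pos_diff < 2)
rpe_cond = torch.zeros(4, 4, dtype=rpe.dtype)
rpe_cond[mask] = rpe[pos_diff[mask]]
# rpe_cond:
# [[0.50, 0.00, 0.00, 0.00],
#  [0.25, 0.50, 0.00, 0.00],
#  [0.00, 0.25, 0.50, 0.00],
#  [0.00, 0.00, 0.25, 0.50]]
```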

@@ -241,8 +238,15 @@ def create_inputs(
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
max_rpe_context_length = 2
logit_cap = 30.0
rpe_bias = 5 * torch.rand(max_len_extend, dtype=torch.float32, device="cuda")

rpe_bias = device_zeros(max_rpe_context_length + 1, dtype=torch.float32)
rpe_bias.copy_(
5 * torch.rand(max_rpe_context_length + 1, dtype=torch.float32, device="cuda")
)
rpe_bias[max_rpe_context_length] = 0
print(rpe_bias)

return (
q_extend,
@@ -262,16 +266,17 @@ def create_inputs(
max_len_extend,
logit_cap,
rpe_bias,
max_rpe_context_length,
)


# TODO: Investigate errors on MI250.
@require_e2e
@require_cdna3
# @require_cdna3
@pytest.mark.parametrize("shape", get_test_shapes("extend"))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_scheduling", [False])
@pytest.mark.parametrize("is_causal", [False, True])
@pytest.mark.parametrize("is_causal", [False])
@pytest.mark.parametrize(
"mfma_variant",
[
@@ -308,6 +313,7 @@ def testExtendAttention(
max_len_extend,
logit_cap,
_,
_,
) = create_inputs(shape, dtype)
shape.max_seq_len = max_len_extend

@@ -403,11 +409,11 @@ def testExtendAttention(

# TODO: Investigate errors on MI250.
@require_e2e
@require_cdna3
# @require_cdna3
@pytest.mark.parametrize("shape", get_test_shapes("extend"))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_scheduling", [False])
@pytest.mark.parametrize("is_causal", [True])
@pytest.mark.parametrize("is_causal", [False])
@pytest.mark.parametrize(
"mfma_variant",
[
@@ -443,6 +449,7 @@ def testExtendRpeAttention(
max_len_extend,
logit_cap,
rpe_bias,
max_rpe_context_length,
) = create_inputs(shape, dtype)
shape.max_seq_len = max_len_extend

@@ -497,6 +504,7 @@
dynamic_symbols=dynamic_symbols,
dynamic_symbols_map=dynamic_symbols_map,
):
log2e = 1.44269504089
mb_qk = extend_attention_rpe(
q_extend,
k_extend,
Expand All @@ -508,8 +516,9 @@ def testExtendRpeAttention(
b_seq_len,
b_seq_len_extend,
b_start_loc_extend,
rpe_bias,
rpe_bias * log2e,
output,
max_rpe_context_length=max_rpe_context_length,
)

if dump_generated_mlir:
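
A note on the rpe_bias * log2e passed to extend_attention_rpe above: assuming the
kernel exponentiates with exp2 rather than exp (which is why the layer scale already
folds in log2(e)), a bias that is added after the score has been scaled must itself
be pre-scaled by the same factor. A small sketch of the identity being relied on:

```python
import torch

log2e = 1.44269504089  # log2(e)
s = torch.randn(4)  # stand-in for the scaled-down QK scores
b = torch.randn(4)  # stand-in for the RPE bias

# exp(s + b) == exp2(s * log2e + b * log2e), so if the kernel only multiplies the
# score by log2e before exp2, the bias has to arrive already multiplied by log2e.
assert torch.allclose(torch.exp(s + b), torch.exp2(s * log2e + b * log2e))
```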
