Update handling on aligned lengths
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Jan 17, 2025
1 parent 8da3e95 commit f874ef3
Showing 5 changed files with 49 additions and 32 deletions.
19 changes: 19 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -10,6 +10,9 @@
import math
from typing import Any, Callable, ClassVar, Optional, List, Type, Dict
from dataclasses import dataclass
import sympy.functions
import sympy.functions.elementary
import sympy.functions.elementary.piecewise
import torch.fx as fx
import torch.utils._pytree as pytree
from collections import namedtuple
@@ -409,6 +412,9 @@ def _get_const(val):
expr = expr.subs(idxc.subs)
# Why affine, for now simply create indexing expressions.
# This can easily be adapted to affine expressions later.
select_stack = []
if isinstance(expr, sympy.Piecewise):
assert len(expr.args) == 2 and expr.args[1][1], f"Unsupported piecewise {expr}"
for term in sympy.postorder_traversal(expr):
match term:
case sympy.Symbol():
@@ -506,6 +512,19 @@ def _get_const(val):
stack.append(base)
case sympy.UnevaluatedExpr():
continue
case sympy.functions.elementary.piecewise.ExprCondPair():
cond = stack.pop()
expr = stack.pop()
select_stack.append(cond)
select_stack.append(expr)
continue
case sympy.Piecewise():
expr = select_stack.pop()
cond = select_stack.pop()
last_expr = select_stack.pop()
last_cond = select_stack.pop()
res = arith_d.select(last_cond, last_expr, expr)
stack.append(res)
case _:
raise CodegenError(f"Can not handle {type(term)} : {term}")

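A note on the Piecewise handling above: the new cases lower a two-branch sympy.Piecewise into a single arith.select. Each ExprCondPair moves its (condition, value) pair from the main stack onto select_stack, and the Piecewise node then pops both pairs and selects on the first condition, relying on the assert that the second condition is True. The following minimal sketch (illustrative only, not code from this commit) mirrors that stack discipline, with a tagged tuple standing in for arith_d.select and a generic fallback standing in for the per-operator emission:

import sympy
from sympy.functions.elementary.piecewise import ExprCondPair

def lower_to_select(expr):
    stack, select_stack = [], []
    for term in sympy.postorder_traversal(expr):
        if isinstance(term, ExprCondPair):
            cond, val = stack.pop(), stack.pop()  # value was pushed before condition
            select_stack += [cond, val]
        elif isinstance(term, sympy.Piecewise):
            val, cond = select_stack.pop(), select_stack.pop()
            last_val, last_cond = select_stack.pop(), select_stack.pop()
            # the second branch is the default (condition True), so select on the first
            stack.append(("select", last_cond, last_val, val))
        elif not term.args:
            stack.append(term)  # leaf: Symbol, Integer, BooleanTrue, ...
        else:
            # stand-in for the per-operator emission (Add, Mul, ceiling, Mod, Lt, ...)
            operands = [stack.pop() for _ in term.args][::-1]
            stack.append((type(term).__name__, *operands))
    return stack.pop()

T, U, WG = sympy.symbols("T U WORKGROUP_0", integer=True, nonnegative=True)
count = sympy.Piecewise((sympy.ceiling(T / U), WG < sympy.Mod(T, U)),
                        (sympy.floor(T / U), True))
lower_to_select(count)  # -> a nested ("select", <cond>, <then>, <else>) tuple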
10 changes: 8 additions & 2 deletions iree/turbine/kernel/wave/constraints.py
@@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Optional, Callable
from sympy import ceiling, Piecewise, floor

from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
@@ -374,6 +374,7 @@ class WorkgroupConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
workgroup_dim: int
apply_fn: Optional[Callable] = None
primary: Optional[bool] = True

def __post_init__(self):
@@ -394,6 +395,8 @@ def count(self) -> IndexExpr:
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
if self.apply_fn:
return IndexSequence(self.apply_fn(self.wg_dim), 1)
return IndexSequence(self.wg_dim * self.tile_size, 1)
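The new optional apply_fn generalizes how a workgroup constraint turns the workgroup id into a start index: when supplied, it replaces the default wg_dim * tile_size mapping. A small stand-alone sketch of that dispatch (illustrative only, not code from this commit; WG, T and U are placeholder symbols):

import sympy

WG, T, U = sympy.symbols("WORKGROUP_0 T U", integer=True, nonnegative=True)

def wg_start(tile_size, apply_fn=None, wg=WG):
    # mirrors WorkgroupConstraint.apply(): apply_fn, if given, wins over wg * tile_size
    return apply_fn(wg) if apply_fn else wg * tile_size

wg_func = lambda w: w * sympy.floor(T / U) + sympy.Min(w, sympy.Mod(T, U))
aligned = wg_start(sympy.ceiling(T / U))     # default mapping: WG * ceil(T/U)
balanced = wg_start(None, apply_fn=wg_func)  # custom mapping, as in the template below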


@@ -409,7 +412,7 @@ def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]
raise ValueError(
"Multiple constraints in the same workgroup dimension are currently not supported."
)
grid: list[IndexExpr] = [constraint.count for constraint in wg_constraints]
grid: list[IndexExpr] = [constraint.count for constraint in sorted_constraints]
return grid


@@ -426,12 +429,15 @@ class TilingConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
induction_var: Optional[IndexExpr] = None
iters: Optional[IndexExpr] = None

@property
def count(self) -> IndexExpr:
"""
Returns an expression for the number of iterations in the loop.
"""
if self.iters:
return self.iters
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
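Likewise, iters lets a tiling constraint carry an explicit trip count instead of the derived ceil(dim / tile_size), which is what allows the per-workgroup Piecewise count used in the paged-decode template below. A small stand-alone sketch of the same dispatch (illustrative only, not code from this commit):

import sympy

WG, T, U = sympy.symbols("WORKGROUP_0 T U", integer=True, nonnegative=True)

def trip_count(dim, tile_size, iters=None):
    # mirrors TilingConstraint.count: an explicit iteration count wins
    return iters if iters is not None else sympy.ceiling(dim / tile_size)

derived = trip_count(T, 64)  # ceiling(T / 64)
per_wg = trip_count(T, 1, iters=sympy.Piecewise(
    (sympy.ceiling(T / U), WG < sympy.Mod(T, U)),
    (sympy.floor(T / U), True),
))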
36 changes: 13 additions & 23 deletions iree/turbine/kernel/wave/templates/paged_decode_attention.py
@@ -74,17 +74,17 @@ def phase_0_constraints():

# T represents the indices of the sequence tokens.
# T is dynamic and is distributed across workgroups and is tiled.
#
# Each workgroup will process:
# T = min(BLOCK, max(SEQ_LEN - WG_IDX[U] * BLOCK, 0)) tokens
# where BLOCK = ceil(SEQ_LEN / U)
#
# While T and U are related to one another, since we do not know SEQ_LEN
# we define them as symbolic aliases with different workgroup tile sizes.
# The tile size for T is set to BLOCK_T = ceil(SEQ_LEN / U) and will also
# be defined within the kernel.
constraints += [tkw.WorkgroupConstraint(T, BLOCK_T, 0, primary=False)]
constraints += [tkw.TilingConstraint(T, 1)]
count = sympy.Piecewise(
(sympy.ceiling(T / U), WORKGROUP_0 < sympy.Mod(T, U)),
(sympy.floor(T / U), True),
)
wg_func = lambda wg: wg * sympy.floor(T / U) + sympy.Min(wg, sympy.Mod(T, U))
constraints += [
tkw.WorkgroupConstraint(
T, BLOCK_T, 0, apply_fn=lambda wg: wg_func(wg), primary=False
)
]
constraints += [tkw.TilingConstraint(T, 1, iters=count)]

# BH is the kv-head index and is distributed across workgroups.
# B is the query index and is distributed like BH but with a different
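With these constraints, the token dimension T is split the standard balanced way: the first T mod U workgroups process ceil(T/U) tokens and the rest floor(T/U), with workgroup w starting at w * floor(T/U) + min(w, T mod U). A quick sympy check (illustrative only, not code from this commit) that the start offsets and counts tile [0, T) exactly for an unaligned case such as T = 19 tokens over U = 8 splits:

import sympy

T, U, WG = sympy.symbols("T U WORKGROUP_0", integer=True, nonnegative=True)
count = sympy.Piecewise((sympy.ceiling(T / U), WG < sympy.Mod(T, U)),
                        (sympy.floor(T / U), True))
start = WG * sympy.floor(T / U) + sympy.Min(WG, sympy.Mod(T, U))

T_val, U_val = 19, 8  # e.g. an unaligned kv length split over num_kv_splits workgroups
spans = [(int(start.subs({T: T_val, U: U_val, WG: w})),
          int(count.subs({T: T_val, U: U_val, WG: w})))
         for w in range(U_val)]
# spans == [(0, 3), (3, 3), (6, 3), (9, 2), (11, 2), (13, 2), (15, 2), (17, 2)]
assert sum(c for _, c in spans) == T_val
assert all(s + c == spans[w + 1][0] for w, (s, c) in enumerate(spans[:-1]))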
@@ -206,16 +206,7 @@ def phase_0(
# The request index is used to load the appropriate entries from the block table.
req_index = tkw.read(request_indices, elements_per_thread=1)
# The sequence length is used to control the bounds of the loop over T.
orig_seq_length = tkw.read(sequence_lengths, elements_per_thread=1)

# The dimension T and its workgroup tile size BLOCK_T are both dynamic
# and set below.
tile_size = tkw.apply_expr(orig_seq_length, lambda x: sympy.ceiling(x / U))
tkw.set_symbol(BLOCK_T, tile_size)
seq_length = tkw.apply_expr(
orig_seq_length,
lambda x: sympy.Min(BLOCK_T, sympy.Max(0, x - WORKGROUP_0 * BLOCK_T)),
)
seq_length = tkw.read(sequence_lengths, elements_per_thread=1)
tkw.set_symbol(T, seq_length)

# TODO: Add if statement here in cases where T is 0 to avoid writing nans for the output.
@@ -229,7 +220,6 @@ def loop(
acc: tkl.Register[S, N, B, tkl.f32],
):
q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
tkw.set_symbol(T, orig_seq_length)
block_indices = tkw.read(
block_table,
elements_per_thread=1,
@@ -325,7 +315,7 @@ def repeat(

dynamic_symbols_0 = [T]
dynamic_symbols_1 = []
dynamic_symbols_map_0 = {T: 1}
dynamic_symbols_map_0 = {T: shape[7]}
dynamic_symbols_map_1 = {}

return (
4 changes: 1 addition & 3 deletions lit_tests/kernel/wave/attention.py
@@ -1088,7 +1088,7 @@ def repeat(
@run_test
def test_paged_flash_decoding():
# (B, M, N, K1, K2, BH, S)
shape = (128, 1, 32, 32, 64, 4, 8)
shape = (128, 1, 32, 32, 64, 4, 8, 128)
max_tokens = 2048
num_kv_splits = 8
mfma_variant = tkw.MMAType.F32_16x16x16_F16
@@ -1127,8 +1127,6 @@ def test_paged_flash_decoding():

# CHECK: func.func @phase_0
# CHECK-COUNT-2: vector.load
# CHECK: arith.maxsi
# CHECK: arith.minsi
# CHECK-COUNT-2: vector.load
# CHECK: scf.for
# CHECK-COUNT-9: vector.maskedload
12 changes: 8 additions & 4 deletions tests/kernel/wave/attention/paged_attention_test.py
@@ -26,6 +26,7 @@
from torch.testing import assert_allclose
from ..common.utils import (
require_e2e,
require_cdna3,
enable_scheduling_barriers,
dump_generated_mlir,
)
@@ -41,8 +42,7 @@
NUM_BLOCKS = [128]
# First item is query length, second item is key/value length.
# In decode, query length is always one.
# TODO: Check with more queries and unaligned shapes.
SEQ_LENS = [[(1, 16), (1, 8)]]
SEQ_LENS = [[(1, 19), (1, 34), (1, 27)]]


# From: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/torch_native_backend.py
@@ -173,7 +173,9 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)


# TODO: Investigate errors on MI250.
@require_e2e
@require_cdna3
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand Down Expand Up @@ -223,7 +225,9 @@ def testPagedFlashDecoding(
# TODO: The block table entries should be able to be a random number
# in the range [0, num_blocks * block_size), but that fails for now.
# As a workaround, the maximum value is set to num_seqs - 1.
block_table = device_randint(0, num_seqs, (num_seqs, max_kv_len), dtype=torch.int32)
block_table = device_randint(
0, num_blocks, (num_seqs, num_blocks), dtype=torch.int32
)
request_indices = device_arange(num_seqs, dtype=torch.int32)
kv_lens_tensor = device_zeros(num_seqs, dtype=torch.int32)
for i in range(len(kv_lens)):
@@ -238,7 +242,7 @@ def testPagedFlashDecoding(
M = 1
N = head_size
BH = num_kv_heads
shape = (B, M, N, K1, K2, BH, S)
shape = (B, M, N, K1, K2, BH, S, num_blocks)
num_kv_splits = 8
(
phase_0,
