Skip to content

Commit

Permalink
Update handling on aligned lengths
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Jan 17, 2025
1 parent 5779cd1 commit 36a3fe2
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 26 deletions.
19 changes: 19 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import math
from typing import Any, Callable, ClassVar, Optional, List, Type, Dict
from dataclasses import dataclass
import sympy.functions
import sympy.functions.elementary
import sympy.functions.elementary.piecewise
import torch.fx as fx
import torch.utils._pytree as pytree
from collections import namedtuple
Expand Down Expand Up @@ -409,6 +412,9 @@ def _get_const(val):
expr = expr.subs(idxc.subs)
# Why affine? For now, simply create indexing expressions.
# This can easily be adapted to affine expressions later.
select_stack = []
if isinstance(expr, sympy.Piecewise):
assert len(expr.args) == 2 and expr.args[1][1], f"Unsupported piecewise {expr}"
for term in sympy.postorder_traversal(expr):
match term:
case sympy.Symbol():
Expand Down Expand Up @@ -506,6 +512,19 @@ def _get_const(val):
stack.append(base)
case sympy.UnevaluatedExpr():
continue
case sympy.functions.elementary.piecewise.ExprCondPair():
cond = stack.pop()
expr = stack.pop()
select_stack.append(cond)
select_stack.append(expr)
continue
case sympy.Piecewise():
expr = select_stack.pop()
cond = select_stack.pop()
last_expr = select_stack.pop()
last_cond = select_stack.pop()
res = arith_d.select(last_cond, last_expr, expr)
stack.append(res)
case _:
raise CodegenError(f"Can not handle {type(term)} : {term}")

Expand Down
10 changes: 8 additions & 2 deletions iree/turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Optional, Callable
from sympy import ceiling, Piecewise, floor

from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
Expand Down Expand Up @@ -374,6 +374,7 @@ class WorkgroupConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
workgroup_dim: int
apply_fn: Optional[Callable] = None
primary: Optional[bool] = True

def __post_init__(self):
Expand All @@ -394,6 +395,8 @@ def count(self) -> IndexExpr:
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
    """
    Returns the index sequence contributed by this workgroup constraint.

    When `apply_fn` is provided, the first argument of the returned
    `IndexSequence` is `apply_fn(wg_dim)`; otherwise it is
    `wg_dim * tile_size`. The second argument is 1 in both cases.
    """
    # Test against None explicitly so the override is selected by whether
    # a callable was supplied, not by the callable object's truthiness.
    if self.apply_fn is not None:
        return IndexSequence(self.apply_fn(self.wg_dim), 1)
    return IndexSequence(self.wg_dim * self.tile_size, 1)


Expand All @@ -409,7 +412,7 @@ def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]
raise ValueError(
"Multiple constraints in the same workgroup dimension are currently not supported."
)
grid: list[IndexExpr] = [constraint.count for constraint in wg_constraints]
grid: list[IndexExpr] = [constraint.count for constraint in sorted_constraints]
return grid


Expand All @@ -426,12 +429,15 @@ class TilingConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
induction_var: Optional[IndexExpr] = None
iters: Optional[IndexExpr] = None

@property
def count(self) -> IndexExpr:
    """
    Returns an expression for the number of iterations in the loop.

    An explicitly provided `iters` takes precedence; otherwise the count
    is `ceiling(dim / tile_size)`.
    """
    # Compare against None instead of relying on truthiness: an explicit
    # zero iteration count (e.g. `0` or `sympy.Integer(0)`) is falsy and
    # would otherwise be silently replaced by the default expression.
    if self.iters is not None:
        return self.iters
    return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
Expand Down
34 changes: 12 additions & 22 deletions iree/turbine/kernel/wave/templates/paged_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,17 @@ def phase_0_constraints():

# T represents the indices of the sequence tokens.
# T is dynamic and is distributed across workgroups and is tiled.
#
# Each workgroup will process:
# T = min(BLOCK, max(SEQ_LEN - WG_IDX[U] * BLOCK, 0)) tokens
# where BLOCK = ceil(SEQ_LEN / U)
#
# While T and U are related to one another, since we do not know SEQ_LEN
# we define them as symbolic aliases with different workgroup tile sizes.
# The tile size for T is set to BLOCK_T = ceil(SEQ_LEN / U) and will also
# be defined within the kernel.
constraints += [tkw.WorkgroupConstraint(T, BLOCK_T, 0, primary=False)]
constraints += [tkw.TilingConstraint(T, 1)]
count = sympy.Piecewise(
(sympy.ceiling(T / U), WORKGROUP_0 < sympy.Mod(T, U)),
(sympy.floor(T / U), True),
)
wg_func = lambda wg: wg * sympy.floor(T / U) + sympy.Min(wg, sympy.Mod(T, U))
constraints += [
tkw.WorkgroupConstraint(
T, BLOCK_T, 0, apply_fn=lambda wg: wg_func(wg), primary=False
)
]
constraints += [tkw.TilingConstraint(T, 1, iters=count)]

# BH is the kv-head index and is distributed across workgroups.
# B is the query index and is distributed like BH but with a different
Expand Down Expand Up @@ -206,16 +206,7 @@ def phase_0(
# The request index is used to load the appropriate entries from the block table.
req_index = tkw.read(request_indices, elements_per_thread=1)
# The sequence length is used to control the bounds of the loop over T.
orig_seq_length = tkw.read(sequence_lengths, elements_per_thread=1)

# The dimension T and its workgroup tile size BLOCK_T are both dynamic
# and set below.
tile_size = tkw.apply_expr(orig_seq_length, lambda x: sympy.ceiling(x / U))
tkw.set_symbol(BLOCK_T, tile_size)
seq_length = tkw.apply_expr(
orig_seq_length,
lambda x: sympy.Min(BLOCK_T, sympy.Max(0, x - WORKGROUP_0 * BLOCK_T)),
)
seq_length = tkw.read(sequence_lengths, elements_per_thread=1)
tkw.set_symbol(T, seq_length)

# TODO: Add an if statement here for the case where T is 0 to avoid writing NaNs to the output.
Expand All @@ -229,7 +220,6 @@ def loop(
acc: tkl.Register[S, N, B, tkl.f32],
):
q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
tkw.set_symbol(T, orig_seq_length)
block_indices = tkw.read(
block_table,
elements_per_thread=1,
Expand Down
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,8 +1127,6 @@ def test_paged_flash_decoding():

# CHECK: func.func @phase_0
# CHECK-COUNT-2: vector.load
# CHECK: arith.maxsi
# CHECK: arith.minsi
# CHECK-COUNT-2: vector.load
# CHECK: scf.for
# CHECK-COUNT-9: vector.maskedload
Expand Down

0 comments on commit 36a3fe2

Please sign in to comment.