diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index e00da5f4..468cf0d1 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -10,6 +10,9 @@ import math
 from typing import Any, Callable, ClassVar, Optional, List, Type, Dict
 from dataclasses import dataclass
+import sympy.functions
+import sympy.functions.elementary
+import sympy.functions.elementary.piecewise
 import torch.fx as fx
 import torch.utils._pytree as pytree
 from collections import namedtuple
@@ -409,6 +412,9 @@ def _get_const(val):
     expr = expr.subs(idxc.subs)
     # Why affine, for now simply create indexing expressions.
     # This can easily be adapted to affine expressions later.
+    select_stack = []
+    if isinstance(expr, sympy.Piecewise):
+        assert len(expr.args) == 2 and expr.args[1][1], f"Unsupported piecewise {expr}"
     for term in sympy.postorder_traversal(expr):
         match term:
             case sympy.Symbol():
@@ -506,6 +512,19 @@ def _get_const(val):
                 stack.append(base)
             case sympy.UnevaluatedExpr():
                 continue
+            case sympy.functions.elementary.piecewise.ExprCondPair():
+                cond = stack.pop()
+                expr = stack.pop()
+                select_stack.append(cond)
+                select_stack.append(expr)
+                continue
+            case sympy.Piecewise():
+                expr = select_stack.pop()
+                cond = select_stack.pop()
+                last_expr = select_stack.pop()
+                last_cond = select_stack.pop()
+                res = arith_d.select(last_cond, last_expr, expr)
+                stack.append(res)
             case _:
                 raise CodegenError(f"Can not handle {type(term)} : {term}")
diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py
index 39606f99..ac73e538 100644
--- a/iree/turbine/kernel/wave/constraints.py
+++ b/iree/turbine/kernel/wave/constraints.py
@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Callable
 from sympy import ceiling, Piecewise, floor
 from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
@@ -374,6 +374,7 @@ class WorkgroupConstraint(Constraint):
     dim: IndexExpr
     tile_size: IndexExpr
     workgroup_dim: int
+    apply_fn: Optional[Callable] = None
     primary: Optional[bool] = True
 
     def __post_init__(self):
@@ -394,6 +395,8 @@ def count(self) -> IndexExpr:
         return ceiling(self.dim / self.tile_size)
 
     def apply(self) -> IndexSequence:
+        if self.apply_fn:
+            return IndexSequence(self.apply_fn(self.wg_dim), 1)
         return IndexSequence(self.wg_dim * self.tile_size, 1)
@@ -409,7 +412,7 @@ def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]
         raise ValueError(
             "Multiple constraints in the same workgroup dimension are currently not supported."
         )
-    grid: list[IndexExpr] = [constraint.count for constraint in wg_constraints]
+    grid: list[IndexExpr] = [constraint.count for constraint in sorted_constraints]
     return grid
@@ -426,12 +429,15 @@ class TilingConstraint(Constraint):
     dim: IndexExpr
     tile_size: IndexExpr
     induction_var: Optional[IndexExpr] = None
+    iters: Optional[IndexExpr] = None
 
     @property
     def count(self) -> IndexExpr:
         """
         Returns an expression for the number of iterations in the loop.
         """
+        if self.iters:
+            return self.iters
         return ceiling(self.dim / self.tile_size)
 
     def apply(self) -> IndexSequence:
diff --git a/iree/turbine/kernel/wave/templates/paged_decode_attention.py b/iree/turbine/kernel/wave/templates/paged_decode_attention.py
index 17a01da4..642ddf2f 100644
--- a/iree/turbine/kernel/wave/templates/paged_decode_attention.py
+++ b/iree/turbine/kernel/wave/templates/paged_decode_attention.py
@@ -74,17 +74,17 @@ def phase_0_constraints():
     # T represents the indices of the sequence tokens.
     # T is dynamic and is distributed across workgroups and is tiled.
-    #
-    # Each workgroup will process:
-    #   T = min(BLOCK, max(SEQ_LEN - WG_IDX[U] * BLOCK, 0)) tokens
-    # where BLOCK = ceil(SEQ_LEN / U)
-    #
-    # While T and U are related to one another, since we do not know SEQ_LEN
-    # we define them as symbolic aliases with different workgroup tile sizes.
-    # The tile size for T is set to BLOCK_T = ceil(SEQ_LEN / U) and will also
-    # be defined within the kernel.
-    constraints += [tkw.WorkgroupConstraint(T, BLOCK_T, 0, primary=False)]
-    constraints += [tkw.TilingConstraint(T, 1)]
+    count = sympy.Piecewise(
+        (sympy.ceiling(T / U), WORKGROUP_0 < sympy.Mod(T, U)),
+        (sympy.floor(T / U), True),
+    )
+    wg_func = lambda wg: wg * sympy.floor(T / U) + sympy.Min(wg, sympy.Mod(T, U))
+    constraints += [
+        tkw.WorkgroupConstraint(
+            T, BLOCK_T, 0, apply_fn=lambda wg: wg_func(wg), primary=False
+        )
+    ]
+    constraints += [tkw.TilingConstraint(T, 1, iters=count)]
 
     # BH is the kv-head index and is distributed across workgroups.
     # B is the query index and is distributed like BH but with a different
@@ -206,16 +206,7 @@ def phase_0(
     # The request index is used to load the appropriate entries from the block table.
     req_index = tkw.read(request_indices, elements_per_thread=1)
     # The sequence length is used to control the bounds of the loop over T.
-    orig_seq_length = tkw.read(sequence_lengths, elements_per_thread=1)
-
-    # The dimension T and its workgroup tile size BLOCK_T are both dynamic
-    # and set below.
-    tile_size = tkw.apply_expr(orig_seq_length, lambda x: sympy.ceiling(x / U))
-    tkw.set_symbol(BLOCK_T, tile_size)
-    seq_length = tkw.apply_expr(
-        orig_seq_length,
-        lambda x: sympy.Min(BLOCK_T, sympy.Max(0, x - WORKGROUP_0 * BLOCK_T)),
-    )
+    seq_length = tkw.read(sequence_lengths, elements_per_thread=1)
     tkw.set_symbol(T, seq_length)
 
     # TODO: Add if statement here in cases where T is 0 to avoid writing nans for the output.
@@ -229,7 +220,6 @@ def loop(
         acc: tkl.Register[S, N, B, tkl.f32],
     ):
         q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
-        tkw.set_symbol(T, orig_seq_length)
         block_indices = tkw.read(
             block_table,
             elements_per_thread=1,
diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py
index 8354d1bd..a1f7599e 100644
--- a/lit_tests/kernel/wave/attention.py
+++ b/lit_tests/kernel/wave/attention.py
@@ -1127,8 +1127,6 @@ def test_paged_flash_decoding():
     # CHECK: func.func @phase_0
     # CHECK-COUNT-2: vector.load
-    # CHECK: arith.maxsi
-    # CHECK: arith.minsi
     # CHECK-COUNT-2: vector.load
     # CHECK: scf.for
     # CHECK-COUNT-9: vector.maskedload