Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paged Decode Attention #387

Merged
merged 3 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions iree/turbine/kernel/_support/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
return subgraph_name, implicit_capture

def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
def _create_graph_input(self, node: fx.Node, name: str, type_expr=None) -> fx.Proxy:
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
# Can use this to check where the freevar has been lifted from.
proxy.node.meta["lifted"] = None
proxy.node.meta["lifted"] = node
return proxy

def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
Expand All @@ -109,7 +109,9 @@ def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
return self.lifted_freevars[proxy]

# Otherwise, create a new input and store it.
new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
new_proxy = self._create_graph_input(
proxy.node, proxy.node.name, proxy.node.type
)
self.lifted_freevars[proxy] = new_proxy

# Propagate freevar usage upwards.
Expand Down
3 changes: 2 additions & 1 deletion iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,13 +1075,14 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]

# This logic relies on fact out mapping is identity.
# This logic relies on the fact that our mapping is identity.
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
subs = {
k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
}
return {
k: IndexSequence.from_expr(mapping[k], subs)
for k in arg.type.symbolic_shape
if k in mapping
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
}

return index
Expand Down
34 changes: 33 additions & 1 deletion iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import math
from typing import Any, Callable, ClassVar, Optional, List, Type, Dict
from dataclasses import dataclass
import sympy.functions
import sympy.functions.elementary
import sympy.functions.elementary.piecewise
import torch.fx as fx
import torch.utils._pytree as pytree
from collections import namedtuple
Expand Down Expand Up @@ -354,6 +357,7 @@ def _ceiling(value):
cmp = arith_d.cmpi(
arith_d.CmpIPredicate.eq, *_broadcast(value.numerator, zero)
)
zero, result = _broadcast(zero, result)
value = arith_d.select(cmp, zero, result)
else:
value = arith_d.ceildivsi(
Expand Down Expand Up @@ -408,6 +412,9 @@ def _get_const(val):
expr = expr.subs(idxc.subs)
# Why affine? For now, simply create indexing expressions.
# This can easily be adapted to affine expressions later.
select_stack = []
if isinstance(expr, sympy.Piecewise):
assert len(expr.args) == 2 and expr.args[1][1], f"Unsupported piecewise {expr}"
for term in sympy.postorder_traversal(expr):
match term:
case sympy.Symbol():
Expand Down Expand Up @@ -464,11 +471,23 @@ def _get_const(val):
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
if _is_integer_like_type(rhs.type):
type = get_type_or_element_type(rhs.type)
if _is_integer_like_type(type):
res = arith_d.maxsi(*_broadcast(lhs, rhs))
else:
res = arith_d.maximumf(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.Min():
rhs = stack.pop()
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
type = get_type_or_element_type(rhs.type)
if _is_integer_like_type(type):
res = arith_d.minsi(*_broadcast(lhs, rhs))
else:
res = arith_d.minimumf(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.logic.boolalg.BooleanFalse():
res = arith_d.constant(IntegerType.get_signless(1), 0)
stack.append(res)
Expand All @@ -493,6 +512,19 @@ def _get_const(val):
stack.append(base)
case sympy.UnevaluatedExpr():
continue
case sympy.functions.elementary.piecewise.ExprCondPair():
cond = stack.pop()
expr = stack.pop()
select_stack.append(cond)
select_stack.append(expr)
continue
case sympy.Piecewise():
expr = select_stack.pop()
cond = select_stack.pop()
last_expr = select_stack.pop()
last_cond = select_stack.pop()
res = arith_d.select(last_cond, last_expr, expr)
stack.append(res)
case _:
raise CodegenError(f"Can not handle {type(term)} : {term}")

Expand Down
63 changes: 52 additions & 11 deletions iree/turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Optional, Callable
from sympy import ceiling, Piecewise, floor

from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
Expand Down Expand Up @@ -374,6 +374,8 @@ class WorkgroupConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
workgroup_dim: int
apply_fn: Optional[Callable] = None
primary: Optional[bool] = True

def __post_init__(self):
self.wg_dim = None
Expand All @@ -393,20 +395,24 @@ def count(self) -> IndexExpr:
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
if self.apply_fn:
return IndexSequence(self.apply_fn(self.wg_dim), 1)
return IndexSequence(self.wg_dim * self.tile_size, 1)


def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
sorted_constraints = sorted(wg_constraints, key=lambda x: x.workgroup_dim)
# Currently not more than one constraint in each dimension supported.
sorted_constraints = sorted(
[x for x in wg_constraints if x.primary], key=lambda x: x.workgroup_dim
)
# Currently, at most one primary constraint per workgroup dimension is supported.
if any(
sorted_constraints[i].workgroup_dim == sorted_constraints[i + 1].workgroup_dim
for i in range(len(sorted_constraints) - 1)
):
raise ValueError(
"Multiple constraints in the same workgroup dimension are currently not supported."
)
grid: list[IndexExpr] = [constraint.count for constraint in wg_constraints]
grid: list[IndexExpr] = [constraint.count for constraint in sorted_constraints]
return grid


Expand All @@ -423,12 +429,15 @@ class TilingConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
induction_var: Optional[IndexExpr] = None
iters: Optional[IndexExpr] = None

@property
def count(self) -> IndexExpr:
"""
Returns an expression for the number of iterations in the loop.
"""
if self.iters:
return self.iters
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
Expand Down Expand Up @@ -482,14 +491,46 @@ def get_constrained_shape(
) -> tuple[IndexExpr]:
"""
Given a shape, workgroup and tiling constraints, returns the shape
of the distributed and tiled tensor.
of the distributed and tiled tensor. The shape is determined using the following
criteria:
0. If no workgroup or tiling constraints are provided, the original shape is used.
1. If only workgroup constraints are provided, the shape is determined by the
tile size of the workgroup constraints.
2. If only tiling constraints are provided, the shape is determined by the
tile size of the tiling constraints.
3. If both workgroup and tiling constraints are provided, the shape is determined
from the tiling constraints*.
* Choosing the tiling constraints uses less shared memory, but we will not be
able to coalesce global memory accesses (minimize_global_loads). Choosing the
workgroup constraints instead would let us coalesce global memory accesses,
but would use more shared memory.
We choose tiling constraints over workgroup constraints because a dimension
is only given both a workgroup and a tiling constraint when we cannot coalesce
global memory accesses anyway, because of constraints like dynamic read indices
for block tables in paged attention.
To enable workgroup constraints instead, we would additionally need to remove
the induction variables from the global read and shared write indices and
ensure that they get hoisted out of the loop.
"""
constrained_shape = list(shape)
all_same_type = lambda x, type: all(
isinstance(constraint, type) for constraint in x
)
for i, dim in enumerate(shape):
for constraint in constraints:
if isinstance(constraint, WorkgroupConstraint) or isinstance(
constraint, TilingConstraint
):
if dim == constraint.dim:
constrained_shape[i] = constraint.tile_size
dim_constraints = [
constraint
for constraint in constraints
if isinstance(constraint, (WorkgroupConstraint, TilingConstraint))
and dim == constraint.dim
]
if not dim_constraints:
continue
if all_same_type(dim_constraints, WorkgroupConstraint) or all_same_type(
dim_constraints, TilingConstraint
):
constrained_shape[i] = dim_constraints[0].tile_size
continue
constrained_shape[i] = [
x.tile_size for x in dim_constraints if isinstance(x, TilingConstraint)
][0]
return tuple(constrained_shape)
44 changes: 41 additions & 3 deletions iree/turbine/kernel/wave/expansion/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Reshape,
GetResult,
MMA,
SetSymbol,
ApplyExpr,
Broadcast,
)
from ..._support.indexing import IndexingContext, IndexSymbol
import itertools
Expand Down Expand Up @@ -453,7 +456,7 @@ def populate_inputs(
mma_metadata.last_mma_node = True
new_nodes_to_expand.append((arg, mma_metadata))
continue
case Allocate():
case Allocate() | SetSymbol() | ApplyExpr():
alloc_metadata = deepcopy(metadata)
alloc_metadata.do_not_expand = True
new_nodes_to_expand.append((arg, alloc_metadata))
Expand All @@ -462,6 +465,7 @@ def populate_inputs(
new_nodes_to_expand.append((arg, metadata))

nodes_to_expand.extend(new_nodes_to_expand)

return nodes_to_expand


Expand Down Expand Up @@ -519,7 +523,9 @@ def expand_node(
)

# Check if the node has already been expanded, if so return early.
key = ExpansionInfo(node, get_indexed_dims(expanded_dims, node))
indexed_dims = get_indexed_dims(expanded_dims, node)

key = ExpansionInfo(node, indexed_dims)
if key in expansion_context:
update_users(node, expansion_context[key], metadata, expansion_context)
return nodes_to_expand
Expand Down Expand Up @@ -614,7 +620,33 @@ def fixup_mma_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
first.replace_all_uses_with_except(second, [exclude])


def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
def get_mma_indexed_dims(
    mma: MMA,
    original_indexed_dims: tuple[tuple[IndexSymbol, int]],
    expansion_context: ExpansionContext,
):
    """
    Find the expanded copy of `mma` with the largest reduction-dim index.

    Scans every key stored in `expansion_context.expansion_context` and keeps
    the keys that
      * refer to `mma`, and
      * agree with `original_indexed_dims` on every dimension listed there.
    Among those keys, the one with the largest index along
    `mma.reduction_dim` is selected.

    Args:
        mma: The node whose expanded copies are searched. Only its
            `reduction_dim` attribute and equality with `key.node` are used.
        original_indexed_dims: (dimension, index) pairs that a candidate key
            must match exactly.
        expansion_context: Object whose `expansion_context` attribute maps
            keys (with `node` and `indexed_dims` attributes) to expanded nodes.

    Returns:
        The `indexed_dims` of the selected key, or None if no key matches.
    """
    reduction_dim = mma.reduction_dim
    required_dims = dict(original_indexed_dims)
    best_indexed_dims = None
    max_reduction_index = -1
    for key in expansion_context.expansion_context:
        if not key.node == mma:
            continue
        candidate_dims = dict(key.indexed_dims)
        # Every dimension of the original key must be present with the same index.
        if any(
            name not in candidate_dims or candidate_dims[name] != index
            for name, index in required_dims.items()
        ):
            continue
        # A copy that is not indexed along the reduction dimension cannot be
        # ranked, so skip it instead of failing with a KeyError.
        reduction_index = candidate_dims.get(reduction_dim)
        if reduction_index is None:
            continue
        if reduction_index > max_reduction_index:
            best_indexed_dims = key.indexed_dims
            max_reduction_index = reduction_index
    return best_indexed_dims


def fixup_reduction_nodes(
trace: CapturedTrace,
expansion_context: ExpansionContext,
):
reduction_context = expansion_context.reduction_context
for reduction in trace.walk(lambda x: isinstance(get_custom(x), Reduction)):
reduction = get_custom(reduction)
Expand All @@ -629,6 +661,12 @@ def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionCont
sorted_keys = dict(sorted(reduction_info.outputs.items(), key=lambda x: x[0]))
new_outputs = []
for key in sorted_keys.values():
if key not in expansion_context and isinstance(key.node, MMA):
key = ExpansionInfo(
key.node,
get_mma_indexed_dims(key.node, key.indexed_dims, expansion_context),
)
assert key in expansion_context, f"Key not found: {key}"
new_outputs.append(expansion_context[key].fx_node)
output.update_arg("return_vals", new_outputs)

Expand Down
23 changes: 22 additions & 1 deletion iree/turbine/kernel/wave/expansion/expansion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
Write,
Reshape,
NewRegister,
ReduceOp,
MMA,
)
from ...lang.global_symbols import SHARED_ADDRESS_SPACE
import itertools
Expand All @@ -45,7 +47,7 @@ def __str__(self):


def get_dim_scaling(
constraints: Sequence[Constraint], node: fx.Node
constraints: Sequence[Constraint], node: CustomOp
) -> dict[IndexSymbol, int]:
"""Get the number of expansions for the dimensions based on the constraints for a specific node."""
dim_scaling: dict[IndexSymbol, int] = {}
Expand Down Expand Up @@ -92,6 +94,25 @@ def get_dim_scaling(
)
dim_scaling[constraint.dim] = tile_size // wave_count // vector_size

# Also include dimensions that have no constraints on them and are known.
idxc = IndexingContext.current()
is_static_dim = lambda dim: dim in idxc.subs
is_non_batch = lambda dim: node.vector_shapes[dim] > 0
not_computed = lambda dim: dim not in dim_scaling

for dim in node.indexing_dims:
if not_computed(dim) and is_static_dim(dim) and is_non_batch(dim):
dim_scaling[dim] = idxc.get_static_value(dim) // node.vector_shapes[dim]

# For reduce ops, also include the reduction dimension.
if isinstance(node, ReduceOp):
reduction_dim = node.reduction_dim
if not_computed(reduction_dim) and is_static_dim(reduction_dim):
dim_scaling[reduction_dim] = (
idxc.get_static_value(reduction_dim)
// node.vector_shapes[reduction_dim]
)

return dim_scaling


Expand Down
Loading
Loading