Paged Attention v2
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Jan 17, 2025
1 parent 3287dea commit 8da3e95
Showing 14 changed files with 982 additions and 42 deletions.
8 changes: 5 additions & 3 deletions iree/turbine/kernel/_support/regions.py
@@ -94,10 +94,10 @@ def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
         subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
         return subgraph_name, implicit_capture
 
-    def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
+    def _create_graph_input(self, node: fx.Node, name: str, type_expr=None) -> fx.Proxy:
         proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
         # Can use this to check where the freevar has been lifted from.
-        proxy.node.meta["lifted"] = None
+        proxy.node.meta["lifted"] = node
         return proxy
 
     def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
@@ -109,7 +109,9 @@ def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
             return self.lifted_freevars[proxy]
 
         # Otherwise, create a new input and store it.
-        new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
+        new_proxy = self._create_graph_input(
+            proxy.node, proxy.node.name, proxy.node.type
+        )
         self.lifted_freevars[proxy] = new_proxy
 
         # Propagate freevar usage upwards.
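The regions.py change threads the original node through to the lifted placeholder, so later passes can trace where each lifted free variable came from instead of seeing `None`. A minimal sketch of the pattern, using simplified stand-in classes rather than the real torch.fx tracer or Wave's RegionGraph:

```python
# Sketch only: Node and Tracer here are simplified stand-ins, not the
# real fx.Node / tracer types touched by this commit.
from dataclasses import dataclass, field

@dataclass(eq=False)  # identity-based hashing so nodes can be dict keys
class Node:
    name: str
    meta: dict = field(default_factory=dict)

class Tracer:
    def __init__(self):
        self.lifted_freevars: dict[Node, Node] = {}

    def _create_graph_input(self, source: Node, name: str) -> Node:
        placeholder = Node(name=name)
        # Record where the freevar has been lifted from (was previously None).
        placeholder.meta["lifted"] = source
        return placeholder

    def lift(self, outer: Node) -> Node:
        # Reuse the placeholder if this freevar was already lifted.
        if outer in self.lifted_freevars:
            return self.lifted_freevars[outer]
        inner = self._create_graph_input(outer, outer.name)
        self.lifted_freevars[outer] = inner
        return inner

t = Tracer()
x = Node("x")
assert t.lift(x).meta["lifted"] is x  # provenance is now queryable
```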
3 changes: 2 additions & 1 deletion iree/turbine/kernel/ops/wave_ops.py
@@ -1075,13 +1075,14 @@ def transform_index_backwards(
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
 
-            # This logic relies on fact out mapping is identity.
+            # This logic relies on fact our mapping is identity.
             subs = {
                 k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
             }
             return {
                 k: IndexSequence.from_expr(mapping[k], subs)
                 for k in arg.type.symbolic_shape
+                if k in mapping
             }
 
         return index
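The new `if k in mapping` guard matters because a dynamic-val mapping may cover only a subset of the argument's symbolic shape. A small sympy sketch of the substitution, with hypothetical symbols and index values, assuming the identity output mapping the comment relies on:

```python
import sympy

i, j = sympy.symbols("i j")       # mapping iterators
M, N, K = sympy.symbols("M N K")  # symbolic dims of the argument

index = {M: i + 1, N: 2 * j}      # incoming index, keyed by output dims
mapping = {M: i, N: j}            # dynamic-val mapping; K is absent

# Identity mapping: iterators line up with the output mapping's keys.
subs = {k: index[v] for k, v in zip((i, j), (M, N))}
result = {
    k: mapping[k].subs(subs)
    for k in (M, N, K)            # the argument's symbolic shape
    if k in mapping               # K is skipped instead of raising KeyError
}
print(result)                     # {M: i + 1, N: 2*j}
```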
15 changes: 14 additions & 1 deletion iree/turbine/kernel/wave/codegen.py
@@ -354,6 +354,7 @@ def _ceiling(value):
             cmp = arith_d.cmpi(
                 arith_d.CmpIPredicate.eq, *_broadcast(value.numerator, zero)
             )
+            zero, result = _broadcast(zero, result)
             value = arith_d.select(cmp, zero, result)
         else:
             value = arith_d.ceildivsi(
@@ -464,11 +465,23 @@ def _get_const(val):
                 lhs = stack.pop()
                 _enforce_non_rational(rhs, term)
                 _enforce_non_rational(lhs, term)
-                if _is_integer_like_type(rhs.type):
+                type = get_type_or_element_type(rhs.type)
+                if _is_integer_like_type(type):
                     res = arith_d.maxsi(*_broadcast(lhs, rhs))
                 else:
                     res = arith_d.maximumf(*_broadcast(lhs, rhs))
                 stack.append(res)
+            case sympy.Min():
+                rhs = stack.pop()
+                lhs = stack.pop()
+                _enforce_non_rational(rhs, term)
+                _enforce_non_rational(lhs, term)
+                type = get_type_or_element_type(rhs.type)
+                if _is_integer_like_type(type):
+                    res = arith_d.minsi(*_broadcast(lhs, rhs))
+                else:
+                    res = arith_d.minimumf(*_broadcast(lhs, rhs))
+                stack.append(res)
             case sympy.logic.boolalg.BooleanFalse():
                 res = arith_d.constant(IntegerType.get_signless(1), 0)
                 stack.append(res)
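Two things happen in this codegen hunk: `sympy.Min` gets the same lowering path `sympy.Max` already had, and the integer check now inspects the element type so vector operands dispatch correctly. A pure-Python sketch of that dispatch, with string "types" and op names standing in for the real MLIR values and `arith_d` calls:

```python
import sympy

def get_type_or_element_type(t: str) -> str:
    # "vector<4xi32>" -> "i32"; scalar types pass through unchanged.
    return t.split("x")[-1].rstrip(">") if t.startswith("vector") else t

def lower_minmax(term, operand_type: str) -> str:
    elem = get_type_or_element_type(operand_type)
    is_int = elem.startswith("i")
    if isinstance(term, sympy.Max):
        return "arith.maxsi" if is_int else "arith.maximumf"
    if isinstance(term, sympy.Min):
        return "arith.minsi" if is_int else "arith.minimumf"
    raise NotImplementedError(term)

x, y = sympy.symbols("x y")
# Checking the vector type directly would miss that its elements are integers.
print(lower_minmax(sympy.Min(x, y), "vector<4xi32>"))  # arith.minsi
print(lower_minmax(sympy.Max(x, y), "f32"))            # arith.maximumf
```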
53 changes: 44 additions & 9 deletions iree/turbine/kernel/wave/constraints.py
@@ -374,6 +374,7 @@ class WorkgroupConstraint(Constraint):
     dim: IndexExpr
     tile_size: IndexExpr
     workgroup_dim: int
+    primary: Optional[bool] = True
 
     def __post_init__(self):
         self.wg_dim = None
@@ -397,8 +398,10 @@ def apply(self) -> IndexSequence:
 
 
 def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
-    sorted_constraints = sorted(wg_constraints, key=lambda x: x.workgroup_dim)
-    # Currently not more than one constraint in each dimension supported.
+    sorted_constraints = sorted(
+        [x for x in wg_constraints if x.primary], key=lambda x: x.workgroup_dim
+    )
+    # Currently not more than one primary constraint in each dimension supported.
     if any(
         sorted_constraints[i].workgroup_dim == sorted_constraints[i + 1].workgroup_dim
         for i in range(len(sorted_constraints) - 1)
@@ -482,14 +485,46 @@ def get_constrained_shape(
 ) -> tuple[IndexExpr]:
     """
     Given a shape, workgroup and tiling constraints, returns the shape
-    of the distributed and tiled tensor.
+    of the distributed and tiled tensor. The shape is determined using the following
+    criteria:
+    0. If no workgroup or tiling constraints are provided, the original shape is used.
+    1. If only workgroup constraints are provided, the shape is determined by the
+    tile size of the workgroup constraints.
+    2. If only tiling constraints are provided, the shape is determined by the
+    tile size of the tiling constraints.
+    3. If both workgroup and tiling constraints are provided, the shape is determined
+    from the tiling constraints*.
+    * By choosing tiling constraints, the shared memory used will be less but we will
+    not be able to coalesce global memory accesses (minimize_global_loads). If instead
+    we choose workgroup constraints, we will be able to coalesce global memory accesses
+    but will use more shared memory.
+    We choose tiling constraints over workgroup constraints because workgroup constraints
+    and tiling constraints will only be used when we cannot coalesce global memory
+    accesses because of constraints like dynamic read indices for block tables in
+    paged attention.
+    To enable workgroup constraints instead, we will additionally need to remove induction
+    variables from the global read and shared write indices and ensure that they get
+    hoisted out of the loop.
     """
     constrained_shape = list(shape)
+    all_same_type = lambda x, type: all(
+        isinstance(constraint, type) for constraint in x
+    )
     for i, dim in enumerate(shape):
-        for constraint in constraints:
-            if isinstance(constraint, WorkgroupConstraint) or isinstance(
-                constraint, TilingConstraint
-            ):
-                if dim == constraint.dim:
-                    constrained_shape[i] = constraint.tile_size
+        dim_constraints = [
+            constraint
+            for constraint in constraints
+            if isinstance(constraint, (WorkgroupConstraint, TilingConstraint))
+            and dim == constraint.dim
+        ]
+        if not dim_constraints:
+            continue
+        if all_same_type(dim_constraints, WorkgroupConstraint) or all_same_type(
+            dim_constraints, TilingConstraint
+        ):
+            constrained_shape[i] = dim_constraints[0].tile_size
+            continue
+        constrained_shape[i] = [
+            x.tile_size for x in dim_constraints if isinstance(x, TilingConstraint)
+        ][0]
     return tuple(constrained_shape)
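The docstring's criteria 0-3 reduce to a short per-dimension rule. A standalone sketch with simplified stand-in constraint classes (not the real `WorkgroupConstraint`/`TilingConstraint` dataclasses):

```python
from dataclasses import dataclass

@dataclass
class WG:    # stand-in for WorkgroupConstraint
    dim: str
    tile_size: int

@dataclass
class Tile:  # stand-in for TilingConstraint
    dim: str
    tile_size: int

def constrained_shape(shape: dict[str, int], constraints) -> tuple[int, ...]:
    out = []
    for dim, size in shape.items():
        cs = [c for c in constraints if c.dim == dim]
        if not cs:                                    # criterion 0
            out.append(size)
        elif all(isinstance(c, WG) for c in cs) or all(
            isinstance(c, Tile) for c in cs
        ):                                            # criteria 1 and 2
            out.append(cs[0].tile_size)
        else:                                         # criterion 3: tiling wins
            out.append(next(c.tile_size for c in cs if isinstance(c, Tile)))
    return tuple(out)

# M is constrained both ways, so tiling wins (64); N only by workgroup (128);
# K is unconstrained and keeps its full size.
print(constrained_shape({"M": 1024, "N": 1024, "K": 512},
                        [WG("M", 256), Tile("M", 64), WG("N", 128)]))
# (64, 128, 512)
```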
44 changes: 41 additions & 3 deletions iree/turbine/kernel/wave/expansion/expansion.py
@@ -21,6 +21,9 @@
     Reshape,
     GetResult,
     MMA,
+    SetSymbol,
+    ApplyExpr,
+    Broadcast,
 )
 from ..._support.indexing import IndexingContext, IndexSymbol
 import itertools
@@ -453,7 +456,7 @@ def populate_inputs(
                 mma_metadata.last_mma_node = True
                 new_nodes_to_expand.append((arg, mma_metadata))
                 continue
-            case Allocate():
+            case Allocate() | SetSymbol() | ApplyExpr():
                 alloc_metadata = deepcopy(metadata)
                 alloc_metadata.do_not_expand = True
                 new_nodes_to_expand.append((arg, alloc_metadata))
@@ -462,6 +465,7 @@ def populate_inputs(
                 new_nodes_to_expand.append((arg, metadata))
 
     nodes_to_expand.extend(new_nodes_to_expand)
+
     return nodes_to_expand
 
 
@@ -519,7 +523,9 @@ def expand_node(
     )
 
     # Check if the node has already been expanded, if so return early.
-    key = ExpansionInfo(node, get_indexed_dims(expanded_dims, node))
+    indexed_dims = get_indexed_dims(expanded_dims, node)
+
+    key = ExpansionInfo(node, indexed_dims)
     if key in expansion_context:
         update_users(node, expansion_context[key], metadata, expansion_context)
         return nodes_to_expand
@@ -614,7 +620,33 @@ def fixup_mma_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
         first.replace_all_uses_with_except(second, [exclude])
 
 
-def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
+def get_mma_indexed_dims(
+    mma: MMA,
+    original_indexed_dims: tuple[tuple[IndexSymbol, int]],
+    expansion_context: ExpansionContext,
+):
+    dim = mma.reduction_dim
+    indexed_dims = None
+    max_reduction_dim = -1
+    original_indexed_dims_dict = dict(original_indexed_dims)
+    for key in expansion_context.expansion_context.keys():
+        indexed_dims_dict = dict(key.indexed_dims)
+        if any(
+            dim not in indexed_dims_dict
+            or indexed_dims_dict[dim] != original_indexed_dims_dict[dim]
+            for dim in original_indexed_dims_dict
+        ):
+            continue
+        if key.node == mma and indexed_dims_dict[dim] > max_reduction_dim:
+            indexed_dims = key.indexed_dims
+            max_reduction_dim = indexed_dims_dict[dim]
+    return indexed_dims
+
+
+def fixup_reduction_nodes(
+    trace: CapturedTrace,
+    expansion_context: ExpansionContext,
+):
     reduction_context = expansion_context.reduction_context
     for reduction in trace.walk(lambda x: isinstance(get_custom(x), Reduction)):
         reduction = get_custom(reduction)
@@ -629,6 +661,12 @@ def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
         sorted_keys = dict(sorted(reduction_info.outputs.items(), key=lambda x: x[0]))
         new_outputs = []
         for key in sorted_keys.values():
+            if key not in expansion_context and isinstance(key.node, MMA):
+                key = ExpansionInfo(
+                    key.node,
+                    get_mma_indexed_dims(key.node, key.indexed_dims, expansion_context),
+                )
+                assert key in expansion_context, f"Key not found: {key}"
             new_outputs.append(expansion_context[key].fx_node)
         output.update_arg("return_vals", new_outputs)
 
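`get_mma_indexed_dims` recovers the right expansion key for an MMA output: among all keys recorded for that node which agree with the reduction's key on every dim it records, it picks the one with the highest reduction-dim count, i.e. the last expansion along the reduction axis. A sketch with plain tuples standing in for `ExpansionInfo`:

```python
def find_mma_key(node, reduction_dim, original_dims, all_keys):
    original = dict(original_dims)
    best, best_count = None, -1
    for key_node, indexed_dims in all_keys:
        dims = dict(indexed_dims)
        # Skip keys that disagree with the original on any recorded dim.
        if any(d not in dims or dims[d] != v for d, v in original.items()):
            continue
        if key_node == node and dims[reduction_dim] > best_count:
            best, best_count = indexed_dims, dims[reduction_dim]
    return best

keys = [("mma0", (("M", 0), ("N", 0), ("K", 0))),
        ("mma0", (("M", 0), ("N", 0), ("K", 3)))]
# The reduction's key carries no K entry, so both candidates match on M and N
# and the expansion with the highest K index is chosen.
print(find_mma_key("mma0", "K", (("M", 0), ("N", 0)), keys))
# (('M', 0), ('N', 0), ('K', 3))
```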
23 changes: 22 additions & 1 deletion iree/turbine/kernel/wave/expansion/expansion_utils.py
@@ -23,6 +23,8 @@
     Write,
     Reshape,
     NewRegister,
+    ReduceOp,
+    MMA,
 )
 from ...lang.global_symbols import SHARED_ADDRESS_SPACE
 import itertools
@@ -45,7 +47,7 @@ def __str__(self):
 
 
 def get_dim_scaling(
-    constraints: Sequence[Constraint], node: fx.Node
+    constraints: Sequence[Constraint], node: CustomOp
 ) -> dict[IndexSymbol, int]:
     """Get the number of expansions for the dimensions based on the constraints for a specific node."""
     dim_scaling: dict[IndexSymbol, int] = {}
@@ -92,6 +94,25 @@ def get_dim_scaling(
                 )
             dim_scaling[constraint.dim] = tile_size // wave_count // vector_size
 
+    # Also include dimensions that have no constraints on them and are known.
+    idxc = IndexingContext.current()
+    is_static_dim = lambda dim: dim in idxc.subs
+    is_non_batch = lambda dim: node.vector_shapes[dim] > 0
+    not_computed = lambda dim: dim not in dim_scaling
+
+    for dim in node.indexing_dims:
+        if not_computed(dim) and is_static_dim(dim) and is_non_batch(dim):
+            dim_scaling[dim] = idxc.get_static_value(dim) // node.vector_shapes[dim]
+
+    # For reduce ops, also include the reduction dimension.
+    if isinstance(node, ReduceOp):
+        reduction_dim = node.reduction_dim
+        if not_computed(reduction_dim) and is_static_dim(reduction_dim):
+            dim_scaling[reduction_dim] = (
+                idxc.get_static_value(reduction_dim)
+                // node.vector_shapes[reduction_dim]
+            )
+
     return dim_scaling
 
 
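The `get_dim_scaling` additions fall back to `static_size // vector_shape` for statically known dims that no constraint covers, and do the same for a `ReduceOp`'s reduction dimension. A worked sketch with hypothetical sizes:

```python
static_sizes = {"M": 64, "N": 128, "K": 256}  # stand-in for idxc.subs
vector_shapes = {"M": 16, "N": 16, "K": 32}   # per-dim vector shapes on the node

dim_scaling = {"M": 2}                        # suppose constraints already set M
for dim in ("M", "N"):                        # node.indexing_dims
    if dim not in dim_scaling and dim in static_sizes and vector_shapes[dim] > 0:
        dim_scaling[dim] = static_sizes[dim] // vector_shapes[dim]

# For a ReduceOp, also scale the reduction dimension.
reduction_dim = "K"
if reduction_dim not in dim_scaling and reduction_dim in static_sizes:
    dim_scaling[reduction_dim] = (
        static_sizes[reduction_dim] // vector_shapes[reduction_dim]
    )

print(dim_scaling)  # {'M': 2, 'N': 8, 'K': 8}
```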
37 changes: 22 additions & 15 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -64,16 +64,17 @@ def get_vector_shape(
 
 def _get_symbolic_shape_and_vector_shapes(
     custom: CustomOp,
-    aliases: dict[IndexSymbol, SymbolicAlias],
-    hw_constraint: HardwareConstraint,
 ):
-    # When the memory type has symbolic aliases, use the memory type
-    # as it includes the aliased variables.
-    symbolic_shape = custom.register_type.symbolic_shape
+    register_shape = custom.register_type.symbolic_shape
     vector_shapes = custom.vector_shapes
-    if any([x in custom.memory_type.symbolic_shape for x in aliases]):
-        symbolic_shape = custom.memory_type.symbolic_shape
-    return symbolic_shape, vector_shapes
+    memory_shape = custom.memory_type.symbolic_shape
+    # Check to see if the memory shape does not match with the vector shapes.
+    if not set(memory_shape).issubset(set(vector_shapes.keys())):
+        return register_shape, vector_shapes
+    # Pick the shape with the most dimensions.
+    if len(memory_shape) > len(register_shape):
+        return memory_shape, vector_shapes
+    return register_shape, vector_shapes
 
 
 def partition_strided_operators(trace: CapturedTrace, constraints: list[Constraint]):
@@ -110,18 +111,14 @@ def has_strided_access(node: fx.Node) -> bool:
         return False
 
     strided_operators = trace.walk(has_strided_access)
-    hw_constraint = [c for c in constraints if isinstance(c, HardwareConstraint)][0]
-    aliases = {c.source: c for c in constraints if isinstance(c, SymbolicAlias)}
     for operator in strided_operators:
         custom = get_custom(operator)
         simplified_index = {
             dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
             for dim in custom.index
         }
 
-        symbolic_shape, vector_shapes = _get_symbolic_shape_and_vector_shapes(
-            custom, aliases, hw_constraint
-        )
+        symbolic_shape, vector_shapes = _get_symbolic_shape_and_vector_shapes(custom)
 
         shape = get_vector_shape(vector_shapes, symbolic_shape)
         elements_per_thread = subs_idxc(custom.elements_per_thread)
@@ -334,7 +331,7 @@ def set_derived_index(trace):
             worklist.append((inp, new_index))
 
 
-def verify_nodes(trace: CapturedTrace):
+def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]):
     """
     Verify that all the valid nodes have their index and vector shapes set.
     """
@@ -348,6 +345,16 @@ def verify_nodes(trace: CapturedTrace):
         if isinstance(custom, (Output, Reduction)):
             continue
         assert custom.index, f"Index not set for node {custom.fx_node}"
+        if not custom.vector_shapes:
+            # If vector_shapes is not set, see if it can be derived from the hardware constraints.
+            hw_constraint = get_hardware_constraint(constraints)
+            update_vector_shapes = [
+                dim for dim in custom.index if dim in hw_constraint.vector_shapes
+            ]
+            if update_vector_shapes:
+                custom.vector_shapes = {}
+                for dim in update_vector_shapes:
+                    custom.vector_shapes[dim] = hw_constraint.vector_shapes[dim]
         assert custom.vector_shapes, f"Vector shapes not set for node {custom.fx_node}"
 
 
@@ -357,7 +364,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
     set_thread_dependent_index(constraints, mma_index, trace)
     set_derived_index(trace)
     resolve_thread_shapes(trace, constraints)
-    verify_nodes(trace)
+    verify_nodes(trace, constraints)
 
 
 def compute_stride(
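The rewritten `_get_symbolic_shape_and_vector_shapes` no longer needs the symbolic-alias plumbing: it keeps the register shape when the memory shape is not fully covered by the node's vector shapes, and otherwise prefers whichever shape has more dimensions. A condensed sketch with hypothetical dim names:

```python
def pick_shape(register_shape, memory_shape, vector_shapes):
    # Memory shape not covered by vector shapes -> fall back to register shape.
    if not set(memory_shape).issubset(vector_shapes):
        return register_shape
    # Otherwise pick the shape with the most dimensions.
    return memory_shape if len(memory_shape) > len(register_shape) else register_shape

# Memory carries an extra block-table dim that the vector shapes know about,
# so it wins; previously this relied on explicitly tracked symbolic aliases.
print(pick_shape(("M", "N"), ("B", "M", "N"), {"B": 1, "M": 16, "N": 16}))
# ('B', 'M', 'N')
```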
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/symbolic_constraints.py
@@ -6,7 +6,7 @@
 
 from iree.turbine.kernel._support.indexing import IndexExpr, IndexSymbol
 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Optional, Sequence
 from .utils import subs_idxc
 from .constraints import (
     Constraint,