
Paged Attention v2
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Jan 16, 2025
1 parent 7f22dc8 commit 5e6ef37
Showing 14 changed files with 929 additions and 22 deletions.
8 changes: 5 additions & 3 deletions iree/turbine/kernel/_support/regions.py
@@ -94,10 +94,10 @@ def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
return subgraph_name, implicit_capture

- def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
+ def _create_graph_input(self, node: fx.Node, name: str, type_expr=None) -> fx.Proxy:
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
# Can use this to check where the freevar has been lifted from.
- proxy.node.meta["lifted"] = None
+ proxy.node.meta["lifted"] = node
return proxy

def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
@@ -109,7 +109,9 @@ def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
return self.lifted_freevars[proxy]

# Otherwise, create a new input and store it.
- new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
+ new_proxy = self._create_graph_input(
+     proxy.node, proxy.node.name, proxy.node.type
+ )
self.lifted_freevars[proxy] = new_proxy

# Propagate freevar usage upwards.
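Note on the regions.py change: the lifted placeholder's meta["lifted"] field now records the source node rather than None, so a downstream pass can trace a placeholder back to the free variable it was lifted from. A minimal sketch of such a lookup, using only the torch.fx API (the helper name is illustrative, not part of this commit):

```python
import torch.fx as fx

def find_lift_origins(graph: fx.Graph) -> dict[fx.Node, fx.Node]:
    """Map each lifted placeholder back to the node it was lifted from."""
    origins: dict[fx.Node, fx.Node] = {}
    for node in graph.nodes:
        # _create_graph_input now stores the source node instead of None.
        if node.op == "placeholder" and node.meta.get("lifted") is not None:
            origins[node] = node.meta["lifted"]
    return origins
```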
3 changes: 2 additions & 1 deletion iree/turbine/kernel/ops/wave_ops.py
@@ -1075,13 +1075,14 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]

- # This logic relies on fact out mapping is identity.
+ # This logic relies on fact our mapping is identity.
subs = {
k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
}
return {
k: IndexSequence.from_expr(mapping[k], subs)
for k in arg.type.symbolic_shape
+ if k in mapping
}

return index
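Note on the wave_ops.py change: because the output mapping is the identity, iterator symbols can be paired positionally with the output dimensions to build the substitution, and the new `if k in mapping` guard skips dims absent from the dynamic-value mapping. A small sympy sketch of the substitution idea (all symbol names here are illustrative):

```python
import sympy

i, j = sympy.symbols("i j")  # mapping iterators
M, N = sympy.symbols("M N")  # output dimensions (identity-mapped)
index = {M: sympy.Symbol("m0"), N: sympy.Symbol("n0")}  # incoming index

# subs = {iterator: index[output_dim]}, as in transform_index_backwards.
subs = {k: index[v] for k, v in zip((i, j), (M, N))}

# A dynamic-value mapping written in iterators can now be rewritten
# in terms of the argument's index expressions.
mapping = {M: i + 1, N: j}
back_index = {k: mapping[k].subs(subs) for k in (M, N) if k in mapping}
print(back_index)  # {M: m0 + 1, N: n0}
```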
15 changes: 14 additions & 1 deletion iree/turbine/kernel/wave/codegen.py
@@ -354,6 +354,7 @@ def _ceiling(value):
cmp = arith_d.cmpi(
arith_d.CmpIPredicate.eq, *_broadcast(value.numerator, zero)
)
+ zero, result = _broadcast(zero, result)
value = arith_d.select(cmp, zero, result)
else:
value = arith_d.ceildivsi(
@@ -464,11 +465,23 @@ def _get_const(val):
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
- if _is_integer_like_type(rhs.type):
+ type = get_type_or_element_type(rhs.type)
+ if _is_integer_like_type(type):
res = arith_d.maxsi(*_broadcast(lhs, rhs))
else:
res = arith_d.maximumf(*_broadcast(lhs, rhs))
stack.append(res)
+ case sympy.Min():
+     rhs = stack.pop()
+     lhs = stack.pop()
+     _enforce_non_rational(rhs, term)
+     _enforce_non_rational(lhs, term)
+     type = get_type_or_element_type(rhs.type)
+     if _is_integer_like_type(type):
+         res = arith_d.minsi(*_broadcast(lhs, rhs))
+     else:
+         res = arith_d.minimumf(*_broadcast(lhs, rhs))
+     stack.append(res)
case sympy.logic.boolalg.BooleanFalse():
res = arith_d.constant(IntegerType.get_signless(1), 0)
stack.append(res)
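Note on the codegen.py change: the new sympy.Min case mirrors the existing sympy.Max lowering; operands are popped off the expression stack, broadcast, and dispatched to a signed-integer or float op depending on the element type. A rough stand-in for the dispatch in plain Python (strings replace the actual arith_d calls, and the _broadcast step is elided):

```python
import sympy

def lower_min_max(term, stack, is_integer: bool):
    """Sketch: pop rhs then lhs, pick the MLIR op by operand type.

    The real code uses arith_d.maxsi/minsi for signed integers and
    arith_d.maximumf/minimumf for floats.
    """
    rhs, lhs = stack.pop(), stack.pop()
    if isinstance(term, sympy.Max):
        op = "arith.maxsi" if is_integer else "arith.maximumf"
    else:  # sympy.Min
        op = "arith.minsi" if is_integer else "arith.minimumf"
    stack.append((op, lhs, rhs))
    return stack

x, y = sympy.symbols("x y")
print(lower_min_max(sympy.Min(x, y), ["%a", "%b"], True))
# [('arith.minsi', '%a', '%b')]
```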
7 changes: 5 additions & 2 deletions iree/turbine/kernel/wave/constraints.py
@@ -374,6 +374,7 @@ class WorkgroupConstraint(Constraint):
dim: IndexExpr
tile_size: IndexExpr
workgroup_dim: int
+ primary: Optional[bool] = True

def __post_init__(self):
self.wg_dim = None
@@ -397,8 +398,10 @@ def apply(self) -> IndexSequence:


def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
- sorted_constraints = sorted(wg_constraints, key=lambda x: x.workgroup_dim)
- # Currently not more than one constraint in each dimension supported.
+ sorted_constraints = sorted(
+     [x for x in wg_constraints if x.primary], key=lambda x: x.workgroup_dim
+ )
+ # Currently not more than one primary constraint in each dimension supported.
if any(
sorted_constraints[i].workgroup_dim == sorted_constraints[i + 1].workgroup_dim
for i in range(len(sorted_constraints) - 1)
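Note on the constraints.py change: with the new primary flag, several workgroup constraints may now share a workgroup dimension, but only primary ones shape the launch grid. A hedged sketch of the intended usage, assuming the names are importable from the module at the file path shown above and that WorkgroupConstraint accepts these keyword fields:

```python
import sympy
# Assumed import path, following the file path in this diff:
from iree.turbine.kernel.wave.constraints import WorkgroupConstraint, get_grid_shape

M, N, BLOCK_M, BLOCK_N = sympy.symbols("M N BLOCK_M BLOCK_N")

constraints = [
    WorkgroupConstraint(dim=M, tile_size=BLOCK_M, workgroup_dim=0),
    # Secondary constraint on the same workgroup dim: excluded from the grid.
    WorkgroupConstraint(dim=N, tile_size=BLOCK_N, workgroup_dim=0, primary=False),
]
# get_grid_shape filters on x.primary, so the secondary constraint on
# dim 0 no longer trips the duplicate-dimension check.
grid = get_grid_shape(constraints)
```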
44 changes: 41 additions & 3 deletions iree/turbine/kernel/wave/expansion/expansion.py
@@ -21,6 +21,9 @@
Reshape,
GetResult,
MMA,
+ SetSymbol,
+ ApplyExpr,
+ Broadcast,
)
from ..._support.indexing import IndexingContext, IndexSymbol
import itertools
@@ -453,7 +456,7 @@ def populate_inputs(
mma_metadata.last_mma_node = True
new_nodes_to_expand.append((arg, mma_metadata))
continue
- case Allocate():
+ case Allocate() | SetSymbol() | ApplyExpr():
alloc_metadata = deepcopy(metadata)
alloc_metadata.do_not_expand = True
new_nodes_to_expand.append((arg, alloc_metadata))
@@ -462,6 +465,7 @@ def populate_inputs(
new_nodes_to_expand.append((arg, metadata))

nodes_to_expand.extend(new_nodes_to_expand)

return nodes_to_expand


@@ -519,7 +523,9 @@ def expand_node(
)

# Check if the node has already been expanded, if so return early.
- key = ExpansionInfo(node, get_indexed_dims(expanded_dims, node))
+ indexed_dims = get_indexed_dims(expanded_dims, node)
+
+ key = ExpansionInfo(node, indexed_dims)
if key in expansion_context:
update_users(node, expansion_context[key], metadata, expansion_context)
return nodes_to_expand
@@ -614,7 +620,33 @@ def fixup_mma_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
first.replace_all_uses_with_except(second, [exclude])


- def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
+ def get_mma_indexed_dims(
+     mma: MMA,
+     original_indexed_dims: tuple[tuple[IndexSymbol, int]],
+     expansion_context: ExpansionContext,
+ ):
+     dim = mma.reduction_dim
+     indexed_dims = None
+     max_reduction_dim = -1
+     original_indexed_dims_dict = dict(original_indexed_dims)
+     for key in expansion_context.expansion_context.keys():
+         indexed_dims_dict = dict(key.indexed_dims)
+         if any(
+             dim not in indexed_dims_dict
+             or indexed_dims_dict[dim] != original_indexed_dims_dict[dim]
+             for dim in original_indexed_dims_dict
+         ):
+             continue
+         if key.node == mma and indexed_dims_dict[dim] > max_reduction_dim:
+             indexed_dims = key.indexed_dims
+             max_reduction_dim = indexed_dims_dict[dim]
+     return indexed_dims
+
+
+ def fixup_reduction_nodes(
+     trace: CapturedTrace,
+     expansion_context: ExpansionContext,
+ ):
reduction_context = expansion_context.reduction_context
for reduction in trace.walk(lambda x: isinstance(get_custom(x), Reduction)):
reduction = get_custom(reduction)
@@ -629,6 +661,12 @@ def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext):
sorted_keys = dict(sorted(reduction_info.outputs.items(), key=lambda x: x[0]))
new_outputs = []
for key in sorted_keys.values():
+ if key not in expansion_context and isinstance(key.node, MMA):
+     key = ExpansionInfo(
+         key.node,
+         get_mma_indexed_dims(key.node, key.indexed_dims, expansion_context),
+     )
+ assert key in expansion_context, f"Key not found: {key}"
new_outputs.append(expansion_context[key].fx_node)
output.update_arg("return_vals", new_outputs)

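Note on the expansion.py change: when a reduction output's key is missing from the expansion context and refers to an MMA node, get_mma_indexed_dims scans the recorded expansion keys for the copy of that MMA which matches the original key on every dim it names and has the largest count along the reduction dimension. A distilled version of that selection rule (dims and counts are illustrative):

```python
import sympy

M, N, K = sympy.symbols("M N K")

# The reduction output's key typically lacks the reduction dim K; the
# keys recorded for the expanded MMA copies include it.
original = {M: 0, N: 1}
candidates = [
    {M: 0, N: 1, K: 0},
    {M: 0, N: 1, K: 3},  # chosen: matches M and N, largest K
    {M: 1, N: 1, K: 7},  # skipped: M differs from the original key
]

def pick(original, candidates, reduction_dim):
    """Mirror of the loop in get_mma_indexed_dims over candidate keys."""
    best, best_k = None, -1
    for cand in candidates:
        if any(d not in cand or cand[d] != v for d, v in original.items()):
            continue
        if cand[reduction_dim] > best_k:
            best, best_k = cand, cand[reduction_dim]
    return best

print(pick(original, candidates, K))  # {M: 0, N: 1, K: 3}
```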
23 changes: 22 additions & 1 deletion iree/turbine/kernel/wave/expansion/expansion_utils.py
@@ -23,6 +23,8 @@
Write,
Reshape,
NewRegister,
+ ReduceOp,
+ MMA,
)
from ...lang.global_symbols import SHARED_ADDRESS_SPACE
import itertools
@@ -45,7 +47,7 @@ def __str__(self):


def get_dim_scaling(
- constraints: Sequence[Constraint], node: fx.Node
+ constraints: Sequence[Constraint], node: CustomOp
) -> dict[IndexSymbol, int]:
"""Get the number of expansions for the dimensions based on the constraints for a specific node."""
dim_scaling: dict[IndexSymbol, int] = {}
@@ -92,6 +94,25 @@ def get_dim_scaling(
)
dim_scaling[constraint.dim] = tile_size // wave_count // vector_size

+ # Also include dimensions that have no constraints on them and are known.
+ idxc = IndexingContext.current()
+ is_static_dim = lambda dim: dim in idxc.subs
+ is_non_batch = lambda dim: node.vector_shapes[dim] > 0
+ not_computed = lambda dim: dim not in dim_scaling
+
+ for dim in node.indexing_dims:
+     if not_computed(dim) and is_static_dim(dim) and is_non_batch(dim):
+         dim_scaling[dim] = idxc.get_static_value(dim) // node.vector_shapes[dim]
+
+ # For reduce ops, also include the reduction dimension.
+ if isinstance(node, ReduceOp):
+     reduction_dim = node.reduction_dim
+     if not_computed(reduction_dim) and is_static_dim(reduction_dim):
+         dim_scaling[reduction_dim] = (
+             idxc.get_static_value(reduction_dim)
+             // node.vector_shapes[reduction_dim]
+         )

return dim_scaling


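Note on the expansion_utils.py change: unconstrained but statically known dimensions now get an expansion factor of static size divided by per-node vector shape, with batch dims (vector shape 0) skipped. A worked example with illustrative numbers:

```python
# Illustrative numbers only: static sizes and per-node vector shapes.
static_sizes = {"M": 128, "K2": 64}
vector_shapes = {"M": 32, "K2": 16, "B": 0}  # B == 0 marks a batch dim

dim_scaling = {
    dim: static_sizes[dim] // vector_shapes[dim]
    for dim in static_sizes
    if vector_shapes.get(dim, 0) > 0  # the is_non_batch check
}
print(dim_scaling)  # {'M': 4, 'K2': 4}
```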
16 changes: 14 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -71,6 +71,8 @@ def _get_symbolic_shape_and_vector_shapes(
# as it includes the aliased variables.
symbolic_shape = custom.register_type.symbolic_shape
vector_shapes = custom.vector_shapes
+ # TODO: Remove the need for this.
+ return custom.memory_type.symbolic_shape, vector_shapes
if any([x in custom.memory_type.symbolic_shape for x in aliases]):
symbolic_shape = custom.memory_type.symbolic_shape
return symbolic_shape, vector_shapes
@@ -334,7 +336,7 @@ def set_derived_index(trace):
worklist.append((inp, new_index))


- def verify_nodes(trace: CapturedTrace):
+ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]):
"""
Verify that all the valid nodes have their index and vector shapes set.
"""
@@ -348,6 +350,16 @@
if isinstance(custom, (Output, Reduction)):
continue
assert custom.index, f"Index not set for node {custom.fx_node}"
+ if not custom.vector_shapes:
+     # If vector_shapes is not set, see if it can be derived from the hardware constraints.
+     hw_constraint = get_hardware_constraint(constraints)
+     update_vector_shapes = [
+         dim for dim in custom.index if dim in hw_constraint.vector_shapes
+     ]
+     if update_vector_shapes:
+         custom.vector_shapes = {}
+         for dim in update_vector_shapes:
+             custom.vector_shapes[dim] = hw_constraint.vector_shapes[dim]
assert custom.vector_shapes, f"Vector shapes not set for node {custom.fx_node}"


@@ -357,7 +369,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
set_thread_dependent_index(constraints, mma_index, trace)
set_derived_index(trace)
resolve_thread_shapes(trace, constraints)
- verify_nodes(trace)
+ verify_nodes(trace, constraints)


def compute_stride(
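Note on the index_sequence_analysis.py change: verify_nodes now falls back to the hardware constraint when a node's vector_shapes is unset, taking only the dims of the node's index that the constraint knows about. A condensed, standalone sketch of that fallback (the helper name and dims are illustrative):

```python
def derive_vector_shapes(index_dims, hw_vector_shapes):
    """Fallback sketch: keep every dim of the node's index that the
    hardware constraint has a vector shape for; None if nothing matches
    (the assert in verify_nodes would then still fire)."""
    derived = {d: hw_vector_shapes[d] for d in index_dims if d in hw_vector_shapes}
    return derived or None

# Illustrative dims: the node indexes M and K; the HW constraint knows M, N.
print(derive_vector_shapes(["M", "K"], {"M": 16, "N": 16}))  # {'M': 16}
```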
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/symbolic_constraints.py
@@ -6,7 +6,7 @@

from iree.turbine.kernel._support.indexing import IndexExpr, IndexSymbol
from dataclasses import dataclass
- from typing import Callable
+ from typing import Callable, Optional, Sequence
from .utils import subs_idxc
from .constraints import (
Constraint,
(Diffs for the remaining 6 changed files are not shown.)
