[TKW] Implement support for multiple iter args on Reduction (iree-org#166)

The main motivation behind this PR is to enable multiple induction
variables/iter args on the same tiled "Reduction" loop. To enable this, we
did a couple of things:

1. Enable lowering/expansion of `operator.getitem` (the op that extracts
multiple results in Python, i.e. `res0, res1 = fn`) by templating it on
`GetResult(CustomOp)`, since they have the same args and interface and can
reuse most of the indexing/expansion helpers (see sketch 1 below).

2. Introduce `res_idx`, a variable representing which result index of an op
we are referring to during expansion and in the context map. This is useful
for ops that have more than one result/variable as output (see sketch 2
below).

3. Fix a bug in `_expand_reduction` by hoisting the iteration over and
expansion of `reduction.init_args` out of the loop that iterates over and
expands the reduction loop's `yield`/`return_val`. The size of `init_args` is
expected to match the size of `yield`/`return_val`, so with N
iter_args/yields we ended up expanding the `init_args` N x N times instead of
N times. We had not noticed this before because we had only been using a
single init_arg/iter arg, and 1x1 == 1 (see sketch 3 below).

4. Introduce a canonicalization pattern to fold chains of GetResult. By
semantics/design, GetResult is only expected to extract a single result, so a
chain of GetResult ops can be collapsed into a single one. This helps clean
up the IR (see sketch 4 below).

Item 4 also helps circumvent an issue where Reduction and GetResult are
expanded completely on their own, without following the per-dimension DFS
structure of the rest of the expansion code. This becomes especially
problematic with multiple iter args, since Getitem does not expect its
source value to be expanded without it.
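
Sketch 1 (for item 1) — a toy illustration only, not the real `define_py_op`
machinery: a hypothetical registry maps a Python callable to a custom-op
class, which is all the "templating" needs so that `operator.getitem` reuses
`GetResult`'s existing indexing/expansion path.

```python
# Toy sketch only: PY_OP_REGISTRY / register_py_op stand in for define_py_op,
# and this GetResult is a plain dataclass, not the real CustomOp subclass.
import operator
from dataclasses import dataclass

PY_OP_REGISTRY: dict[object, type] = {}

def register_py_op(py_callable):
    def wrap(cls):
        PY_OP_REGISTRY[py_callable] = cls  # map the Python callable to an op class
        return cls
    return wrap

@register_py_op(operator.getitem)
@dataclass
class GetResult:
    value: object  # the multi-result op being indexed (e.g. a reduction)
    res_idx: int   # which result to extract

# A tracer can now turn each element of `res_max, res_sum = repeat` into a
# GetResult node instead of needing a separate getitem op class:
node = PY_OP_REGISTRY[operator.getitem](value="reduction_0", res_idx=1)
print(node)  # GetResult(value='reduction_0', res_idx=1)
```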
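
Sketch 2 (for item 2) — a minimal sketch with illustrative names of why the
expansion context is keyed on `(node, indexed_dims, res_idx)`: without the
result index, the expanded nodes for result 0 and result 1 of the same
reduction would collide on the same key.

```python
# Minimal sketch: the context maps (node, per-dimension indices, result index)
# to the expanded node, e.g. (read_0_0_0, ((M,0),(N,0),(K,1)), 0) -> read_0_0_1.
from typing import Any

context: dict[tuple[Any, tuple, int], Any] = {}

def cache_expansion(node: Any, indexed_dims: tuple, res_idx: int, expanded: Any) -> None:
    context[(node, indexed_dims, res_idx)] = expanded

# Two results of the same reduction at the same dim query stay distinct:
cache_expansion("reduction_0", (("M", 0), ("N", 0)), 0, "getresult_max_0_0")
cache_expansion("reduction_0", (("M", 0), ("N", 0)), 1, "getresult_sum_0_0")
assert len(context) == 2  # without res_idx the second entry would overwrite the first
```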
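
Sketch 3 (for item 3) — an illustrative count, with hypothetical helper
names, of how often `init_args` get expanded per dimension query before and
after hoisting the `init_args` loop out of the loop over yielded values.

```python
# Illustrative count of init_args expansions per dimension query.
def init_arg_expansions(num_iter_args: int, hoisted: bool) -> int:
    count = 0
    if hoisted:
        # Fixed behaviour: init_args are expanded once each, alongside the
        # loop over yielded values rather than inside it.
        for _init_arg in range(num_iter_args):
            count += 1
    else:
        # Buggy behaviour: init_args were re-expanded for every yielded value.
        for _yielded in range(num_iter_args):
            for _init_arg in range(num_iter_args):
                count += 1
    return count

print(init_arg_expansions(1, hoisted=False), init_arg_expansions(1, hoisted=True))  # 1 1
print(init_arg_expansions(2, hoisted=False), init_arg_expansions(2, hoisted=True))  # 4 2
```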
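
Sketch 4 (for item 4) — a toy model using a plain dataclass rather than the
real fx node type, mirroring what the new `remove_chained_getresult` utility
does: a GetResult whose source is itself a GetResult is redundant, and all of
its uses are rewired to that source.

```python
# Toy model: collapse chains of get_result nodes onto the innermost one.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    op: str
    value: Optional["Node"] = None
    res_idx: int = 0

reduction = Node("reduction")
inner = Node("get_result", reduction, 1)  # extracts result #1 of the reduction
outer = Node("get_result", inner)         # redundant: a get_result of a get_result

def fold_chain(node: Node) -> Node:
    # Walk toward the source while it is itself a get_result; in the real pass
    # all uses of `node` are rewired to the returned node via replace_all_uses_with.
    while node.op == "get_result" and node.value is not None and node.value.op == "get_result":
        node = node.value
    return node

assert fold_chain(outer) is inner
```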

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored and IanNod committed Sep 30, 2024
1 parent 99b4339 commit e2cc0b9
Showing 6 changed files with 212 additions and 45 deletions.
105 changes: 105 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -946,6 +946,111 @@ def repeat(
# CHECK: scf.yield %[[ACC_REDUCE]] : vector<1xf16>


# This test ensures that we can handle multiple induction variables in a reduction properly.
@run_test
def test_multiple_reduction_iv():
M = tkl.sym.M
N = tkl.sym.N
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, N, 0)]
constraints += [tkw.TilingConstraint(N, BLOCK_N)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
d: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
):
init_max = tkl.Register[M, tkl.f16](-1e6)
init_sum = tkl.Register[M, tkl.f16](0)

@tkw.reduction(N, init_args=[init_max, init_sum])
def repeat(
partial_max: tkl.Register[M, tkl.f16],
partial_sum: tkl.Register[M, tkl.f16],
) -> tkl.Register[M, tkl.f16]:
lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
partial_max = tkw.max(lhs, partial_max, dim=N)
partial_sum = tkw.sum(lhs, partial_sum, dim=N)
return partial_max, partial_sum

res_max, res_sum = repeat
tkw.write(res_max, c, elements_per_thread=1)
tkw.write(res_sum, d, elements_per_thread=1)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

shape = (256, 512)
a = torch.randn(shape, dtype=torch.float16)
c = torch.zeros((shape[0],), dtype=torch.float16)
d = torch.zeros((shape[0],), dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
BLOCK_M: 2,
BLOCK_N: 128,
ELEMS_PER_THREAD: 2,
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
):
print(test(a, c).module_op)
# CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index
# CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index
# CHECK-DAG: %[[INIT_MAX:.+]] = arith.constant dense<0xFC00> : vector<1xf16>
# CHECK-DAG: %[[INIT_SUM:.+]] = arith.constant dense<0.000000e+00> : vector<1xf16>

# Tile Reduction Loop
# CHECK: %[[TILED:.+]]:4 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
# CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_SUM]], %[[ACC2:.+]] = %[[INIT_MAX]], %[[ACC3:.+]] = %[[INIT_SUM]])
# CHECK-SAME: -> (vector<1xf16>, vector<1xf16>, vector<1xf16>, vector<1xf16>) {
# 1st Expanded Local Max Reduction
# CHECK: arith.maximumf {{.*}} : vector<1xf16>
# 1st Expanded Global Max Reduction
# CHECK-COUNT-6: gpu.shuffle xor
# 1st Expanded Accumulator Max Reduction
# CHECK: %[[ACC_MAX_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}}

# 2nd Expanded Local Max Reduction
# CHECK: arith.maximumf {{.*}} : vector<1xf16>
# 2nd Expanded Global Max Reduction
# CHECK-COUNT-6: gpu.shuffle xor
# 2nd Expanded Accumulator Max Reduction
# CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC2]], %{{.*}}

# 1st Expanded Local Sum Reduction
# CHECK: arith.addf {{.*}} : vector<1xf16>
# 1st Expanded Global Sum Reduction
# CHECK-COUNT-6: gpu.shuffle xor
# 1st Expanded Accumulator Sum Reduction
# CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC1]], %{{.*}}

# 2nd Expanded Local Sum Reduction
# CHECK: arith.addf {{.*}} : vector<1xf16>
# 2nd Expanded Global Sum Reduction
# CHECK-COUNT-6: gpu.shuffle xor
# 2nd Expanded Accumulator Sum Reduction
# CHECK: %[[ACC_SUM_1:.+]] = arith.addf %[[ACC3]], %{{.*}}

# CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]]


@run_test
def test_binary_lowerings():
constraints: list[tkw.Constraint] = [
2 changes: 1 addition & 1 deletion shark_turbine/kernel/ops/wave_ops.py
@@ -486,7 +486,6 @@ def post_expansion(self, constraints: list["Constraint"]) -> None:
pass


@define_py_op(operator.getitem)
@define_py_op(operator.add)
@define_py_op(operator.sub)
@define_py_op(operator.mul)
@@ -945,6 +944,7 @@ def register_index(self) -> dict[IndexSymbol, IndexSequence]:
return custom.index


@define_py_op(operator.getitem)
@define_op("get_result")
@dataclass
class GetResult(CustomOp):
89 changes: 61 additions & 28 deletions shark_turbine/kernel/wave/expansion.py
@@ -23,11 +23,11 @@
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.expansion")
# This represents a mapping of a node + indexing into the dimensions to the
# corresponding expanded node in these specific dimensions. An example for a
# record in this map is (read_0_0_0, ((M,0),(N,0),(K,1)) -> read_0_0_1
# This represents a mapping of a node + indexing + res_idx(output index for op with multiple results)
# of node into the dimensions to the corresponding expanded node in these specific dimensions.
# An example for a record in this map is (read_0_0_0, ((M,0),(N,0),(K,1), 0) -> read_0_0_1.
ExpandedNodeMap: TypeAlias = dict[
tuple[CustomOp, tuple[tuple[IndexSymbol, int], ...]], CustomOp
tuple[CustomOp, tuple[tuple[IndexSymbol, int], int, ...]], CustomOp
]


@@ -302,30 +302,39 @@ def _expand_node(
dim_scaling: dict[IndexSymbol, int],
node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
context: ExpandedNodeMap,
res_idx: int = 0,
) -> CustomOp:
"""Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs."""
if isinstance(node, list):
expanded_nodes = []
for elem in node:
expanded_nodes.append(
_expand_node(
elem, trace, dim_query, dim_scaling, node_index_setter, context
elem,
trace,
dim_query,
dim_scaling,
node_index_setter,
context,
res_idx,
).fx_node
)
return expanded_nodes
# If we expanded a node in the same dimensions before, we can reuse it
if (node, get_indexed_dims(dim_query, node)) in context:
if (node, get_indexed_dims(dim_query, node), res_idx) in context:
logger.debug(f"Already expanded node: {node} in {dim_query}")
return context[(node, get_indexed_dims(dim_query, node))]
return context[(node, get_indexed_dims(dim_query, node), res_idx)]
elif isinstance(node, Reduction):
return _expand_reduction(
node, trace, dim_query, dim_scaling, node_index_setter, context
)
elif isinstance(node, GetResult):
elif isinstance(node, Getitem):
res_idx = node.res_idx
elif isinstance(node, GetResult) and not isinstance(node, Getitem):
# The presence of a GetResult node indicates that the reduction has already
# been expanded. Simply return the corresponding node.
reduction = get_custom(node.value)
return context[(reduction, get_indexed_dims(dim_query, reduction))]
return context[(reduction, get_indexed_dims(dim_query, reduction), res_idx)]
elif isinstance(node, Allocate):
# Allocate nodes are not expanded.
return node
@@ -371,12 +380,13 @@ def _expand_node(
dim_scaling,
node_index_setter,
context,
res_idx,
)
new_node.update_arg(i, new_arg)

new_node.post_expansion(constraints)

context[(node, get_indexed_dims(restricted_dims, node))] = new_node
context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
return new_node


@@ -387,6 +397,7 @@ def _expand_reduction(
dim_scaling: dict[IndexSymbol, int],
node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
context: ExpandedNodeMap,
res_idx: int = 0,
) -> CustomOp:
"""Expand a reduction in a specific dimension and recursively proceed to its inputs."""
# Determine the dimensions to expand the reduction from the indexing of its users
@@ -409,44 +420,59 @@
new_output_args = []
new_init_args = []
for dim_vals in get_dim_combinations(dim_scaling, expand_dims):
for arg_idx, arg in output.node_args.items():
dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)}
return_vals = output.return_vals[0]
dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)}
if not isinstance(return_vals, Sequence):
return_vals = [return_vals]
for arg_idx, arg in enumerate(return_vals):
arg = get_custom(arg)
# Add GetResult nodes for the corresponding dimensions
reduction.graph.inserting_after(reduction.fx_node)
new_node = GetResult(reduction.fx_node, len(new_output_args))
new_node.add_to_graph(reduction.graph)
new_node.fx_node.name = get_expanded_name(new_node, dims)
context[(reduction, get_indexed_dims(dims, expand_dims))] = new_node
context[
(reduction, get_indexed_dims(dims, expand_dims), arg_idx)
] = new_node

# Proceed with expansion inside the reduction
new_output_args.append(
_expand_node(arg, trace, dims, dim_scaling, node_index_setter, context)
_expand_node(
arg, trace, dims, dim_scaling, node_index_setter, context, res_idx
)
)

# Proceed with expansion outside the reduction
for init_arg in reduction.init_args:
new_init_args.append(
_expand_node(
get_custom(init_arg),
trace,
dims,
dim_scaling,
node_index_setter,
context,
)
# Proceed with expansion outside the reduction
for init_arg in reduction.init_args:
new_init_args.append(
_expand_node(
get_custom(init_arg),
trace,
dims,
dim_scaling,
node_index_setter,
context,
res_idx,
)
)

# Update init_args and return values
reduction.update_arg(
"init_args", [new_init_arg.fx_node for new_init_arg in new_init_args]
)
output.update_arg("return_vals", [node.fx_node for node in new_output_args])
_handle_reduction_dim(
reduction, output, trace, dim_scaling, node_index_setter, context
reduction,
output,
trace,
dim_scaling,
node_index_setter,
context,
res_idx,
)
# Even though we expanded the reduction in multiple dimensions, we only return
# the node corresponding to the original query
return context[(reduction, get_indexed_dims(dim_query, expand_dims))]
return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)]


def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:
@@ -536,6 +562,7 @@ def _handle_reduction_dim(
dim_scaling: dict[IndexSymbol, int],
node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
context: ExpandedNodeMap,
res_idx: int,
):
# Rediscover iter args
# TODO: Register iter args with the reduction initially so accessing them is easier
@@ -572,7 +599,13 @@
saved_arg = user.node_args[index]
user.update_arg(index, dummy)
new_node = _expand_node(
user, trace, dims, dim_scaling, node_index_setter, context
user,
trace,
dims,
dim_scaling,
node_index_setter,
context,
res_idx,
)

# This expansion always happens, user should never be reused
13 changes: 13 additions & 0 deletions shark_turbine/kernel/wave/utils.py
@@ -123,6 +123,19 @@ def is_removable_operator(node: fx.Node) -> bool:
get_custom(node).graph.erase_node(node)


def remove_chained_getresult(trace: CapturedTrace):
def is_chained_getresult(node: fx.Node) -> bool:
custom = get_custom(node)
return isinstance(custom, GetResult) and isinstance(
get_custom(custom.value), GetResult
)

while removable_nodes := trace.walk(is_chained_getresult):
for node in removable_nodes:
get_custom(node).replace_all_uses_with(get_custom(node).value)
get_custom(node).graph.erase_node(node)


def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]:
"""
Delinearizes a 1D index into a multi-dimensional index
10 changes: 9 additions & 1 deletion shark_turbine/kernel/wave/wave.py
@@ -23,7 +23,12 @@
from .expansion import expand_graph
from .promotion import promote_placeholders
from .hoisting import hoist_allocs
from .utils import canonicalize_module, compile_and_invoke, safe_subs
from .utils import (
canonicalize_module,
compile_and_invoke,
safe_subs,
remove_chained_getresult,
)
from .minimize_global_loads import minimize_global_loads
from .decompose_reduce_ops import decompose_reduce_ops
from .barriers import add_shared_memory_barriers
@@ -205,6 +210,9 @@ def _trace_and_get_kernel_signature(
# Expansion
expand_graph(graph, self.constraints)

# Clean up chains of GetResults
remove_chained_getresult(graph)

# Register analysis to determine register shapes.
determine_register_shape(graph, self.constraints)
