From e2cc0b9e2f35e34b6411ac7212718160af63d1fb Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Fri, 27 Sep 2024 00:24:24 -0700
Subject: [PATCH] [TKW] Implement support for multiple iter args on Reduction
 (#166)

The main motivation behind this PR is to enable multiple induction
variables/iter args on the same tiled "Reduction" loop. To enable this,
we did a couple of things:

1. Enable lowering/expansion of `operator.getitem` (the op that extracts
multiple results in Python, i.e. `res0, res1 = fn`) by mapping it onto
`GetResult(CustomOp)`, since they share the same args and interface and
can reuse most of the indexing/expansion helpers.

2. Introduce `res_idx`, a variable that represents which result index of
an op we are referring to during expansion and in the context map. This
is useful for ops that have more than one result/variable as output.

3. Fix a bug in `_expand_reduction` by hoisting the iteration over and
expansion of `reduction.init_args` out of the loop that iterates over
and expands the `yield`/`return_val` of the reduction loop. The size of
`init_args` is expected to match the size of `yield`/`return_val`, so
with N iter_args/yields we ended up expanding the `init_args` N x N
times instead of N times. We had not seen this before because we had
only ever used 1 init_arg/iter arg, and 1 x 1 == 1.

4. Introduce a canonicalization pattern to fold chains of GetResult.
By design, a GetResult is only expected to extract a single result, so
a GetResult whose source is itself a GetResult can simply be replaced
by that source. This helps clean up the IR.

Item 4 also helps circumvent an issue where Reduction and GetResult are
expanded completely on their own, without following the per-dimension
DFS structure of the rest of the expansion code. This becomes especially
problematic with multiple iter args, since Getitem does not expect its
source value to be expanded without it.

---------

Signed-off-by: Stanley Winata
---
 lit_tests/kernel/wave/codegen.py       | 105 +++++++++++++++++++++++++
 shark_turbine/kernel/ops/wave_ops.py   |   2 +-
 shark_turbine/kernel/wave/expansion.py |  89 ++++++++++++++-------
 shark_turbine/kernel/wave/utils.py     |  13 +++
 shark_turbine/kernel/wave/wave.py      |  10 ++-
 tests/kernel/wave/wave_e2e_test.py     |  38 +++++----
 6 files changed, 212 insertions(+), 45 deletions(-)

diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index d102f3533..e18764b1e 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -946,6 +946,111 @@ def repeat(
         # CHECK: scf.yield %[[ACC_REDUCE]] : vector<1xf16>
 
 
+# This test ensures that we can handle multiple IVs in a reduction properly.
+@run_test
+def test_multiple_reduction_iv():
+    M = tkl.sym.M
+    N = tkl.sym.N
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+            vector_shapes={M: 1, N: BLOCK_N},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(N, N, 0)]
+    constraints += [tkw.TilingConstraint(N, BLOCK_N)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+        d: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+    ):
+        init_max = tkl.Register[M, tkl.f16](-1e6)
+        init_sum = tkl.Register[M, tkl.f16](0)
+
+        @tkw.reduction(N, init_args=[init_max, init_sum])
+        def repeat(
+            partial_max: tkl.Register[M, tkl.f16],
+            partial_sum: tkl.Register[M, tkl.f16],
+        ) -> tkl.Register[M, tkl.f16]:
+            lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
+            partial_max = tkw.max(lhs, partial_max, dim=N)
+            partial_sum = tkw.sum(lhs, partial_sum, dim=N)
+            return partial_max, partial_sum
+
+        res_max, res_sum = repeat
+        tkw.write(res_max, c, elements_per_thread=1)
+        tkw.write(res_sum, d, elements_per_thread=1)
+
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    shape = (256, 512)
+    a = torch.randn(shape, dtype=torch.float16)
+    c = torch.zeros((shape[0],), dtype=torch.float16)
+    d = torch.zeros((shape[0],), dtype=torch.float16)
+    with tk.gen.TestLaunchContext(
+        {
+            M: shape[0],
+            N: shape[1],
+            BLOCK_M: 2,
+            BLOCK_N: 128,
+            ELEMS_PER_THREAD: 2,
+            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
+        },
+        canonicalize=True,
+    ):
+        print(test(a, c, d).module_op)
+        # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index
+        # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index
+        # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index
+        # CHECK-DAG: %[[INIT_MAX:.+]] = arith.constant dense<0xFC00> : vector<1xf16>
+        # CHECK-DAG: %[[INIT_SUM:.+]] = arith.constant dense<0.000000e+00> : vector<1xf16>
+
+        # Tile Reduction Loop
+        # CHECK: %[[TILED:.+]]:4 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
+        # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_SUM]], %[[ACC2:.+]] = %[[INIT_MAX]], %[[ACC3:.+]] = %[[INIT_SUM]])
+        # CHECK-SAME: -> (vector<1xf16>, vector<1xf16>, vector<1xf16>, vector<1xf16>) {
+        # 1st Expanded Local Max Reduction
+        # CHECK: arith.maximumf {{.*}} : vector<1xf16>
+        # 1st Expanded Global Max Reduction
+        # CHECK-COUNT-6: gpu.shuffle xor
+        # 1st Expanded Accumulator Max Reduction
+        # CHECK: %[[ACC_MAX_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
+
+        # 2nd Expanded Local Max Reduction
+        # CHECK: arith.maximumf {{.*}} : vector<1xf16>
+        # 2nd Expanded Global Max Reduction
+        # CHECK-COUNT-6: gpu.shuffle xor
+        # 2nd Expanded Accumulator Max Reduction
+        # CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC2]], %{{.*}}
+
+        # 1st Expanded Local Sum Reduction
+        # CHECK: arith.addf {{.*}} : vector<1xf16>
+        # 1st Expanded Global Sum Reduction
+        # CHECK-COUNT-6: gpu.shuffle xor
+        # 1st Expanded Accumulator Sum Reduction
+        # CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC1]], %{{.*}}
+
+        # 2nd Expanded Local Sum Reduction
+        # CHECK: arith.addf {{.*}} : vector<1xf16>
+        # 2nd Expanded Global Sum Reduction
+        # CHECK-COUNT-6: gpu.shuffle xor
+        # 2nd Expanded Accumulator Sum Reduction
+        # CHECK: %[[ACC_SUM_1:.+]] = arith.addf %[[ACC3]], %{{.*}}
+
+        # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]]
+
+
 @run_test
 def test_binary_lowerings():
     constraints: list[tkw.Constraint] = [
diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index 905095c64..2c38c9c28 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -486,7 +486,6 @@ def post_expansion(self, constraints: list["Constraint"]) -> None:
         pass
 
 
-@define_py_op(operator.getitem)
 @define_py_op(operator.add)
 @define_py_op(operator.sub)
 @define_py_op(operator.mul)
@@ -945,6 +944,7 @@ def register_index(self) -> dict[IndexSymbol, IndexSequence]:
         return custom.index
 
 
+@define_py_op(operator.getitem)
 @define_op("get_result")
 @dataclass
 class GetResult(CustomOp):
diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py
index ebbc2d46e..697850313 100644
--- a/shark_turbine/kernel/wave/expansion.py
+++ b/shark_turbine/kernel/wave/expansion.py
@@ -23,11 +23,11 @@
 from ..lang.global_symbols import *
 
 logger = get_logger("turbine.wave.expansion")
-# This represents a mapping of a node + indexing into the dimensions to the
-# corresponding expanded node in these specific dimensions. An example for a
-# record in this map is (read_0_0_0, ((M,0),(N,0),(K,1)) -> read_0_0_1
+# This represents a mapping of a node + indexing + res_idx (output index for ops with multiple results)
+# into the dimensions to the corresponding expanded node in these specific dimensions.
+# An example for a record in this map is (read_0_0_0, ((M,0),(N,0),(K,1)), 0) -> read_0_0_1.
 ExpandedNodeMap: TypeAlias = dict[
-    tuple[CustomOp, tuple[tuple[IndexSymbol, int], ...]], CustomOp
+    tuple[CustomOp, tuple[tuple[IndexSymbol, int], int, ...]], CustomOp
 ]
@@ -302,6 +302,7 @@ def _expand_node(
     dim_scaling: dict[IndexSymbol, int],
     node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
     context: ExpandedNodeMap,
+    res_idx: int = 0,
 ) -> CustomOp:
     """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs."""
     if isinstance(node, list):
@@ -309,23 +310,31 @@ def _expand_node(
         for elem in node:
             expanded_nodes.append(
                 _expand_node(
-                    elem, trace, dim_query, dim_scaling, node_index_setter, context
+                    elem,
+                    trace,
+                    dim_query,
+                    dim_scaling,
+                    node_index_setter,
+                    context,
+                    res_idx,
                 ).fx_node
             )
         return expanded_nodes
     # If we expanded a node in the same dimensions before, we can reuse it
-    if (node, get_indexed_dims(dim_query, node)) in context:
+    if (node, get_indexed_dims(dim_query, node), res_idx) in context:
         logger.debug(f"Already expanded node: {node} in {dim_query}")
-        return context[(node, get_indexed_dims(dim_query, node))]
+        return context[(node, get_indexed_dims(dim_query, node), res_idx)]
     elif isinstance(node, Reduction):
         return _expand_reduction(
             node, trace, dim_query, dim_scaling, node_index_setter, context
         )
-    elif isinstance(node, GetResult):
+    elif isinstance(node, Getitem):
+        res_idx = node.res_idx
+    elif isinstance(node, GetResult) and not isinstance(node, Getitem):
         # The presence of a GetResult node indicates that the reduction has already
         # been expanded. Simply return the corresponding node.
         reduction = get_custom(node.value)
-        return context[(reduction, get_indexed_dims(dim_query, reduction))]
+        return context[(reduction, get_indexed_dims(dim_query, reduction), res_idx)]
     elif isinstance(node, Allocate):
         # Allocate nodes are not expanded.
         return node
@@ -371,12 +380,13 @@ def _expand_node(
                 dim_scaling,
                 node_index_setter,
                 context,
+                res_idx,
             )
             new_node.update_arg(i, new_arg)
 
     new_node.post_expansion(constraints)
 
-    context[(node, get_indexed_dims(restricted_dims, node))] = new_node
+    context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
     return new_node
 
 
@@ -387,6 +397,7 @@ def _expand_reduction(
     dim_scaling: dict[IndexSymbol, int],
     node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
     context: ExpandedNodeMap,
+    res_idx: int = 0,
 ) -> CustomOp:
     """Expand a reduction in a specific dimension and recursively proceed to its inputs."""
     # Determine the dimensions to expand the reduction from the indexing of its users
@@ -409,32 +420,41 @@ def _expand_reduction(
     new_output_args = []
     new_init_args = []
     for dim_vals in get_dim_combinations(dim_scaling, expand_dims):
-        for arg_idx, arg in output.node_args.items():
-            dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)}
+        return_vals = output.return_vals[0]
+        dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)}
+        if not isinstance(return_vals, Sequence):
+            return_vals = [return_vals]
+        for arg_idx, arg in enumerate(return_vals):
+            arg = get_custom(arg)
             # Add GetResult nodes for the corresponding dimensions
             reduction.graph.inserting_after(reduction.fx_node)
             new_node = GetResult(reduction.fx_node, len(new_output_args))
             new_node.add_to_graph(reduction.graph)
             new_node.fx_node.name = get_expanded_name(new_node, dims)
-            context[(reduction, get_indexed_dims(dims, expand_dims))] = new_node
+            context[
+                (reduction, get_indexed_dims(dims, expand_dims), arg_idx)
+            ] = new_node
 
             # Proceed with expansion inside the reduction
             new_output_args.append(
-                _expand_node(arg, trace, dims, dim_scaling, node_index_setter, context)
+                _expand_node(
+                    arg, trace, dims, dim_scaling, node_index_setter, context, res_idx
+                )
             )
 
-            # Proceed with expansion outside the reduction
-            for init_arg in reduction.init_args:
-                new_init_args.append(
-                    _expand_node(
-                        get_custom(init_arg),
-                        trace,
-                        dims,
-                        dim_scaling,
-                        node_index_setter,
-                        context,
-                    )
+        # Proceed with expansion outside the reduction
+        for init_arg in reduction.init_args:
+            new_init_args.append(
+                _expand_node(
+                    get_custom(init_arg),
+                    trace,
+                    dims,
+                    dim_scaling,
+                    node_index_setter,
+                    context,
+                    res_idx,
                 )
+            )
 
     # Update init_args and return values
     reduction.update_arg(
@@ -442,11 +462,17 @@ def _expand_reduction(
     )
     output.update_arg("return_vals", [node.fx_node for node in new_output_args])
     _handle_reduction_dim(
-        reduction, output, trace, dim_scaling, node_index_setter, context
+        reduction,
+        output,
+        trace,
+        dim_scaling,
+        node_index_setter,
+        context,
+        res_idx,
     )
     # Even though we expanded the reduction in multiple dimensions, we only return
     # the node corresponding to the original query
-    return context[(reduction, get_indexed_dims(dim_query, expand_dims))]
+    return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)]
 
 
 def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:
@@ -536,6 +562,7 @@ def _handle_reduction_dim(
     dim_scaling: dict[IndexSymbol, int],
     node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
     context: ExpandedNodeMap,
+    res_idx: int,
 ):
     # Rediscover iter args
     # TODO: Register iter args with the reduction initially so accessing them is easier
@@ -572,7 +599,13 @@ def _handle_reduction_dim(
                 saved_arg = user.node_args[index]
                 user.update_arg(index, dummy)
                 new_node = _expand_node(
-                    user, trace, dims, dim_scaling, node_index_setter, context
+                    user,
+                    trace,
+                    dims,
+                    dim_scaling,
+                    node_index_setter,
+                    context,
+                    res_idx,
                 )
 
                 # This expansion always happens, user should never be reused
diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py
index 42e5bca3f..dda9013bd 100644
--- a/shark_turbine/kernel/wave/utils.py
+++ b/shark_turbine/kernel/wave/utils.py
@@ -123,6 +123,19 @@ def is_removable_operator(node: fx.Node) -> bool:
             get_custom(node).graph.erase_node(node)
 
 
+def remove_chained_getresult(trace: CapturedTrace):
+    def is_chained_getresult(node: fx.Node) -> bool:
+        custom = get_custom(node)
+        return isinstance(custom, GetResult) and isinstance(
+            get_custom(custom.value), GetResult
+        )
+
+    while removable_nodes := trace.walk(is_chained_getresult):
+        for node in removable_nodes:
+            get_custom(node).replace_all_uses_with(get_custom(node).value)
+            get_custom(node).graph.erase_node(node)
+
+
 def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]:
     """
     Delinearizes a 1D index into a multi-dimensional index
diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py
index eb6003de3..4d19d99fb 100644
--- a/shark_turbine/kernel/wave/wave.py
+++ b/shark_turbine/kernel/wave/wave.py
@@ -23,7 +23,12 @@
 from .expansion import expand_graph
 from .promotion import promote_placeholders
 from .hoisting import hoist_allocs
-from .utils import canonicalize_module, compile_and_invoke, safe_subs
+from .utils import (
+    canonicalize_module,
+    compile_and_invoke,
+    safe_subs,
+    remove_chained_getresult,
+)
 from .minimize_global_loads import minimize_global_loads
 from .decompose_reduce_ops import decompose_reduce_ops
 from .barriers import add_shared_memory_barriers
@@ -205,6 +210,9 @@ def _trace_and_get_kernel_signature(
         # Expansion
         expand_graph(graph, self.constraints)
 
+        # Clean up chains of GetResults
+        remove_chained_getresult(graph)
+
         # Register analysis to determine register shapes.
         determine_register_shape(graph, self.constraints)
 
diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 0e55a3f9a..dbe884245 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -279,7 +279,7 @@ def test(
 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("test_tiled_reduce_max"))
 @xfail_unaligned
-def test_tiled_reduce_max(shape):
+def test_toy_online_softmax(shape):
     M = tkl.sym.M
     N = tkl.sym.N
     wave_size = 64
@@ -303,30 +303,38 @@ def test_tiled_reduce_max(shape):
 
     @tkw.wave(constraints)
     def test(
-        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
-        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
-        c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
+        c: tkl.Memory[M, ADDRESS_SPACE, tkl.f32],
     ):
-        init_max = tkl.Register[M, tkl.f16](-1e6)
+        init_max = tkl.Register[M, tkl.f32](-1e6)
+        init_sum = tkl.Register[M, tkl.f32](0)
 
-        @tkw.reduction(N, init_args=[init_max])
+        @tkw.reduction(N, init_args=[init_max, init_sum])
         def repeat(
-            partial_max: tkl.Register[M, tkl.f16],
-        ) -> tkl.Register[M, tkl.f16]:
+            partial_max: tkl.Register[M, tkl.f32],
+            partial_sum: tkl.Register[M, tkl.f32],
+        ) -> tkl.Register[M, tkl.f32]:
             lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
             rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD)
             res = lhs * rhs
             partial_max = tkw.max(res, partial_max, dim=N)
-            return partial_max
+            partial_sum = tkw.sum(res, partial_sum, dim=N)
+            return partial_max, partial_sum
 
-        tkw.write(repeat, c, elements_per_thread=1)
+        res_max, res_sum = repeat
+        result = res_max / res_sum
+        tkw.write(result, c, elements_per_thread=1)
 
     config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
 
-    a = torch.randn(shape, dtype=torch.float16)
-    b = torch.randn(shape, dtype=torch.float16)
-    c = torch.zeros((shape[0],), dtype=torch.float16)
-    ref = torch.max((a * b), dim=-1)
+    torch.manual_seed(1)
+    a = torch.randn(shape, dtype=torch.float32)
+    b = torch.randn(shape, dtype=torch.float32)
+    c = torch.zeros((shape[0],), dtype=torch.float32)
+    ref_max = torch.max((a * b), dim=-1).values
+    ref_sum = torch.sum((a * b), dim=-1)
+    ref = ref_max / ref_sum
     with tk.gen.TestLaunchContext(
         {
             M: shape[0],
@@ -343,7 +351,7 @@ def repeat(
     # Assert equal does cast to boolean on torch.Tensor
     # which causes issues, hence we cast to numpy before
     # checking.
-    assert_equal(c, ref.values.numpy())
+    assert_allclose(ref, c, atol=0.015)
 
 
 @require_e2e
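
For readers skimming the patch, the user-facing pattern that this change enables is condensed below. This is an illustrative sketch distilled from test_multiple_reduction_iv (lit test) and test_toy_online_softmax (e2e test) in the diffs above; the kernel name max_and_sum and the standalone module layout are illustrative framing only and are not part of the change.

import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw

# Symbolic shapes and tuning parameters, as in the lit test above.
M = tkl.sym.M
N = tkl.sym.N
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints: list[tkw.Constraint] = [
    tkw.HardwareConstraint(
        threads_per_wave=64,
        waves_per_block=(1, 1, 1),
        vector_shapes={M: 1, N: BLOCK_N},
    ),
    tkw.WorkgroupConstraint(M, BLOCK_M, 1),
    tkw.WorkgroupConstraint(N, N, 0),
    tkw.TilingConstraint(N, BLOCK_N),
    tkw.WaveConstraint(M, BLOCK_M),
    tkw.WaveConstraint(N, BLOCK_N),
]


@tkw.wave(constraints)
def max_and_sum(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
    d: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
):
    init_max = tkl.Register[M, tkl.f16](-1e6)
    init_sum = tkl.Register[M, tkl.f16](0)

    # Two init_args -> two iter args carried by the tiled reduction loop.
    @tkw.reduction(N, init_args=[init_max, init_sum])
    def repeat(
        partial_max: tkl.Register[M, tkl.f16],
        partial_sum: tkl.Register[M, tkl.f16],
    ):
        lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
        partial_max = tkw.max(lhs, partial_max, dim=N)
        partial_sum = tkw.sum(lhs, partial_sum, dim=N)
        # Returning a tuple yields multiple results from the reduction.
        return partial_max, partial_sum

    # Tuple unpacking is what lowers to operator.getitem / GetResult nodes,
    # which this patch teaches expansion to handle (keyed by res_idx).
    res_max, res_sum = repeat
    tkw.write(res_max, c, elements_per_thread=1)
    tkw.write(res_sum, d, elements_per_thread=1)

# Compiling and running this kernel follows the same TestLaunchContext setup
# used by test_multiple_reduction_iv in lit_tests/kernel/wave/codegen.py above.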