From e81301bea5995a722b441cb92f4b483d805f3b6d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 9 Dec 2024 07:47:24 +0100 Subject: [PATCH 1/3] WIP --- .../transformations/simplify.py | 37 ++- .../test_constant_substitution.py | 235 +++++++++++++++--- 2 files changed, 227 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..fe676bd99a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -207,25 +207,48 @@ def gt_substitute_compiletime_symbols( validate_all: Perform validation also on intermediate steps. """ - # We will use the `replace` function of the top SDFG, however, lower levels - # are handled using ConstantPropagation. - sdfg.replace_dict(repl) + # The substitution is performed in this way for the following reasons: + # `ConstantPropagation` would propagate values that are used in conditions. + # But it would not replace the strides of an array a symbol used inside a + # Tasklet, except that Tasklet or the array are inside an Nested SDFG. + # However, `replace_dict()` does such things also on the top level. + # For that reason the function will first use this multi stage version. const_prop = dace_passes.ConstantPropagation() const_prop.recursive = True const_prop.progress = False + # Before we will do a first round of DaCe's constant propagation. Followed + # by simplify. The main reason is that if there is an access node in the + # SDFG we would generate an error, if we would use `replace_dict()`. + # Thus we use CP to handle them. + sdfg.view() const_prop.apply_pass( sdfg=sdfg, initial_symbols=repl, _=None, ) - gt_simplify( + if validate_all: + sdfg.validate() + + # New we use the `replace_dict()` function. 
This will get rid of symbols that + # were not handled by constant propagation above. This is mainly the case for + # Symbols used in Tasklets, data descriptors that were not inside a Nested + # SDFG. + sdfg.view() + sdfg.replace_dict(repl) + sdfg.view() + + # Now we will again run constant propagation followed by simplify. This is mostly + # a clean-up pass. + # TODO(phimuell): Investigate if it is needed. + print("=" * 80) + const_prop.apply_pass( sdfg=sdfg, - validate=validate, - validate_all=validate_all, + initial_symbols=repl, + _=None, ) - dace.sdfg.propagation.propagate_memlets_sdfg(sdfg) + sdfg.view() def gt_reduce_distributed_buffering( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py index 04a4f098ef..178ccbe775 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -18,43 +18,11 @@ from . import util -def test_constant_substitution(): - sdfg, nsdfg = _make_sdfg() +import dace - # Ensure that `One` is present. 
- assert len(sdfg.symbols) == 2 - assert len(nsdfg.sdfg.symbols) == 2 - assert len(nsdfg.symbol_mapping) == 2 - assert "One" in sdfg.symbols - assert "One" in nsdfg.sdfg.symbols - assert "One" in nsdfg.symbol_mapping - assert "One" == str(nsdfg.symbol_mapping["One"]) - assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) - assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) - assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) - assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) - assert "One" in sdfg.used_symbols(True) - # Now replace `One` with 1 - gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1}) - - assert len(sdfg.symbols) == 1 - assert len(nsdfg.sdfg.symbols) == 1 - assert len(nsdfg.symbol_mapping) == 1 - assert "One" not in sdfg.symbols - assert "One" not in nsdfg.sdfg.symbols - assert "One" not in nsdfg.symbol_mapping - assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) - assert all( - desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() - ) - assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) - assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) - assert "One" not in sdfg.used_symbols(True) - - -def _make_nested_sdfg() -> dace.SDFG: - sdfg = dace.SDFG("nested") +def _make_nested_sdfg_test_constant_sub() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("nested")) N = dace.symbol(sdfg.add_symbol("N", dace.int32)) One = dace.symbol(sdfg.add_symbol("One", dace.int32)) for name in "ABC": @@ -81,8 +49,8 @@ def _make_nested_sdfg() -> dace.SDFG: return sdfg -def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: - sdfg = dace.SDFG("outer_sdfg") +def _make_sdfg_test_constant_sub() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: + sdfg = dace.SDFG(util.unique_name("outer_sdfg")) N = dace.symbol(sdfg.add_symbol("N", 
dace.int32)) One = dace.symbol(sdfg.add_symbol("One", dace.int32)) for name in "ABCD": @@ -96,7 +64,7 @@ def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: sdfg.arrays["C"].transient = True first_state: dace.SDFGState = sdfg.add_state(is_start_block=True) - nested_sdfg: dace.SDFG = _make_nested_sdfg() + nested_sdfg: dace.SDFG = _make_nested_sdfg_test_constant_sub() nsdfg = first_state.add_nested_sdfg( nested_sdfg, parent=sdfg, @@ -140,3 +108,194 @@ def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: ) sdfg.validate() return sdfg, nsdfg + + +def test_constant_substitution(): + sdfg, nsdfg = _make_sdfg_test_constant_sub() + + # Ensure that `One` is present. + assert len(sdfg.symbols) == 2 + assert len(nsdfg.sdfg.symbols) == 2 + assert len(nsdfg.symbol_mapping) == 2 + assert "One" in sdfg.symbols + assert "One" in nsdfg.sdfg.symbols + assert "One" in nsdfg.symbol_mapping + assert "One" == str(nsdfg.symbol_mapping["One"]) + assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" in sdfg.used_symbols(True) + + # Now replace `One` with 1 + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1, "N": 10}) + + assert len(sdfg.symbols) == 1 + assert len(nsdfg.sdfg.symbols) == 1 + assert len(nsdfg.symbol_mapping) == 1 + assert "One" not in sdfg.symbols + assert "One" not in nsdfg.sdfg.symbols + assert "One" not in nsdfg.symbol_mapping + assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) + assert all( + desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() + ) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in 
nsdfg.sdfg.arrays.values()) + assert "One" not in sdfg.used_symbols(True) + + +def _make_not_wrapped_sdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("not_wrapped_sdfg")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_symbol("N", dace.int64) + sdfg.add_symbol("lim_area", dace.bool_) + + for name in "AB": + sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) + + state.add_mapped_tasklet( + "PreComp", + map_ranges={"__i": "0:N"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out = __in + N", + outputs={"__out": dace.Memlet("B[__i]")}, + external_edges=True, + ) + + stateT = sdfg.add_state(is_start_block=False) + stateT.add_mapped_tasklet( + "Tcomp", + map_ranges={"__i": "0:N"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out = __in + N", + outputs={"__out": dace.Memlet("B[__i]")}, + external_edges=True, + ) + + stateF = sdfg.add_state(is_start_block=False) + stateF.add_mapped_tasklet( + "Fcomp", + map_ranges={"__i": "0:N"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out = __in + 2 * N", + outputs={"__out": dace.Memlet("B[__i]")}, + external_edges=True, + ) + + stateJ = sdfg.add_state(is_start_block=False) + sdfg.add_edge(state, stateT, dace.InterstateEdge(condition="lim_area")) + sdfg.add_edge(state, stateF, dace.InterstateEdge(condition="not lim_area")) + sdfg.add_edge(stateT, stateJ, dace.InterstateEdge()) + sdfg.add_edge(stateF, stateJ, dace.InterstateEdge()) + sdfg.validate() + return sdfg + + +def _make_wrapped_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(util.unique_name("wrapped_sdfg")) + state = sdfg.add_state("WRAPPED", is_start_block=True) + sdfg.add_symbol("lim_area", dace.bool_) + sdfg.add_symbol("N", dace.bool_) + for name in "AB": + sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) + + nsdfg = state.add_nested_sdfg( + sdfg=_make_not_wrapped_sdfg(), + parent=sdfg, + inputs={"A"}, + outputs={"B"}, + symbol_mapping={"lim_area": "lim_area", "N": 
"N"}, + ) + state.add_edge( + state.add_access("A"), None, nsdfg, "A", dace.Memlet.from_array("A", sdfg.arrays["A"]) + ) + state.add_edge( + nsdfg, "B", state.add_access("B"), None, dace.Memlet.from_array("B", sdfg.arrays["B"]) + ) + sdfg.validate() + return sdfg, nsdfg + + +def test_constant_substitution_not_wrapped_sdfg(): + sdfg: dace.SDFG = _make_not_wrapped_sdfg() + assert sdfg.number_of_nodes() > 1 + assert sdfg.free_symbols == {"N", "lim_area"} + map_entries_old: list[dace_nodes.MapEntry] = util.count_nodes( + sdfg, + node_type=dace_nodes.MapEntry, + return_nodes=True, + ) + assert any(str(map_entry.map.range[0][1] + 1) == "N" for map_entry in map_entries_old) + assert all( + node.desc(sdfg).shape == ("N",) + for node in sdfg.all_nodes_recursive() + if isinstance(node, dace_nodes.AccessNode) + ) + + gtx_transformations.gt_substitute_compiletime_symbols( + sdfg, + repl={"N": 10, "lim_area": True}, + validate=True, + validate_all=True, + ) + assert sdfg.number_of_nodes() == 4 + assert len(sdfg.free_symbols) == 0 + map_entries: list[dace_nodes.MapEntry] = util.count_nodes( + sdfg, + node_type=dace_nodes.MapEntry, + return_nodes=True, + ) + assert len(map_entries) == 2 + assert any(str(map_entry.map.range[0][1]) == "9" for map_entry in map_entries) + assert all( + node.desc(sdfg).shape == (10,) + for node in sdfg.all_nodes_recursive() + if isinstance(node, dace_nodes.AccessNode) + ) + + +def test_constant_substitution_wrapped_sdfg(): + sdfg, nsdfg = _make_wrapped_sdfg() + assert sdfg.number_of_nodes() == 1 + assert sdfg.free_symbols == {"N", "lim_area"} + assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 1 + + map_entries_old: list[dace_nodes.MapEntry] = util.count_nodes( + nsdfg.sdfg, + dace_nodes.MapEntry, + return_nodes=True, + ) + assert any(str(map_entry.map.range[0][1] + 1) == "N" for map_entry in map_entries_old) + assert all( + node.desc(nsdfg.sdfg).shape == ("N",) + for node in sdfg.all_nodes_recursive() + if isinstance(node, 
dace_nodes.AccessNode) + ) + + sdfg.view() + gtx_transformations.gt_substitute_compiletime_symbols( + sdfg, + repl={"N": "10", "lim_area": 1}, + validate=True, + validate_all=True, + ) + + assert sdfg.number_of_nodes() == 1 + assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 1 + assert len(sdfg.free_symbols) == 0 + for nsdfg in sdfg.all_sdfgs_recursive(): + assert all( + node.desc(nsdfg).shape == (10,) + for node in sdfg.all_nodes_recursive() + if isinstance(node, dace_nodes.AccessNode) + ) + + map_entries: list[dace_nodes.MapEntry] = util.count_nodes( + nsdfg.sdfg, + dace_nodes.MapEntry, + return_nodes=True, + ) + assert len(map_entries) == 2 + assert any(str(map_entry.map.range[0][1]) == "9" for map_entry in map_entries) From 599ec9d50578679c3ff76b0d4941d34fb6de9fc1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 9 Dec 2024 15:46:35 +0100 Subject: [PATCH 2/3] Changed the constant substitution implementation. This is a bit more stable, but it does not seem to work yet if there is an access node that should be replaced, which was the actual goal of this modification. However, it solves some problems with multiple states; at least it is now a bit more consistent in that case. --- .../transformations/simplify.py | 53 ++++++++----------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index fe676bd99a..e588ebeab0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -205,50 +205,41 @@ def gt_substitute_compiletime_symbols( repl: Maps the name of the symbol to the value it should be replaced with. validate: Perform validation at the end of the function. validate_all: Perform validation also on intermediate steps. 
+ + Note: + Because of [issue 1817](https://github.com/spcl/dace/issues/1817) in DaCe, + the function has to run `gt_simplify()`. However, this is an artefact of + the implementation and will be changed once the bug is solved. """ - # The substitution is performed in this way for the following reasons: - # `ConstantPropagation` would propagate values that are used in conditions. - # But it would not replace the strides of an array a symbol used inside a - # Tasklet, except that Tasklet or the array are inside an Nested SDFG. - # However, `replace_dict()` does such things also on the top level. - # For that reason the function will first use this multi stage version. + # Ideally this function would just call `ConstantPropagation` with the replacement + # `dict` and be done. However, because of [issue 1817](https://github.com/spcl/dace/issues/1817) + # in DaCe this is not possible and we have to do it in this awkward way. + # TODO(phimuell): Fix this strange behaviour. + # First we do replacement on the top level SDFG only. However, we have to filter + # out all names that refer to data descriptors, because the replacement function + # can not handle them. We leave this to `ConstantPropagation`. + arrays = sdfg.arrays + sdfg.replace_dict({sym: value for sym, value in repl.items() if sym not in arrays}) const_prop = dace_passes.ConstantPropagation() const_prop.recursive = True const_prop.progress = False - - # Before we will do a first round of DaCe's constant propagation. Followed - # by simplify. The main reason is that if there is an access node in the - # SDFG we would generate an error, if we would use `replace_dict()`. - # Thus we use CP to handle them. - sdfg.view() const_prop.apply_pass( sdfg=sdfg, initial_symbols=repl, _=None, ) - if validate_all: - sdfg.validate() - # New we use the `replace_dict()` function. This will get rid of symbols that - # were not handled by constant propagation above. 
This is mainly the case for - # Symbols used in Tasklets, data descriptors that were not inside a Nested - # SDFG. - sdfg.view() - sdfg.replace_dict(repl) - sdfg.view() - - # Now we will again run constant propagation followed by simplify. This is mostly - # a clean-up pass. - # TODO(phimuell): Investigate if it is needed. - print("=" * 80) - const_prop.apply_pass( - sdfg=sdfg, - initial_symbols=repl, - _=None, + # To handle some bugs in `ConstantPropagation` we now call simplify. + # TODO(phimuell): Once the bug in DaCe is fixed remove this. + gt_simplify( + sdfg, + validate=False, + validate_all=validate_all, ) - sdfg.view() + if validate_all: + sdfg.validate() def gt_reduce_distributed_buffering( From 9cbf6bc0d0a7c57ff143f85c5ce4bb972db6ad84 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 9 Dec 2024 15:50:51 +0100 Subject: [PATCH 3/3] Updated the tests for the constant substituter. --- .../test_constant_substitution.py | 475 ++++++++++-------- 1 file changed, 267 insertions(+), 208 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py index 178ccbe775..902554d851 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest +from typing import Any, Final, Iterable, Optional, TypeAlias, Union, Literal dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes @@ -18,173 +19,81 @@ from . 
import util -import dace - - -def _make_nested_sdfg_test_constant_sub() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("nested")) - N = dace.symbol(sdfg.add_symbol("N", dace.int32)) - One = dace.symbol(sdfg.add_symbol("One", dace.int32)) - for name in "ABC": - sdfg.add_array( - name=name, - dtype=dace.float64, - shape=(N, N), - strides=(N, One), - transient=False, - ) - state = sdfg.add_state(is_start_block=True) - state.add_mapped_tasklet( - "computation", - map_ranges={"__i0": "0:N", "__i1": "0:N"}, - inputs={ - "__in0": dace.Memlet("A[__i0, __i1]"), - "__in1": dace.Memlet("B[__i0, __i1]"), - }, - code="__out = __in0 + __in1", - outputs={"__out": dace.Memlet("C[__i0, __i1]")}, - external_edges=True, +def _check_shapes( + sdfg: dace.SDFG, + expected_shape: tuple[str, ...], + to_string: bool = True, +) -> bool: + return all( + tuple((str(s) if to_string else s) for s in desc.shape) == expected_shape + for desc in sdfg.arrays.values() ) - sdfg.validate() - return sdfg -def _make_sdfg_test_constant_sub() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: - sdfg = dace.SDFG(util.unique_name("outer_sdfg")) - N = dace.symbol(sdfg.add_symbol("N", dace.int32)) - One = dace.symbol(sdfg.add_symbol("One", dace.int32)) - for name in "ABCD": - sdfg.add_array( - name=name, - dtype=dace.float64, - shape=(N, N), - strides=(N, One), - transient=False, - ) - sdfg.arrays["C"].transient = True - - first_state: dace.SDFGState = sdfg.add_state(is_start_block=True) - nested_sdfg: dace.SDFG = _make_nested_sdfg_test_constant_sub() - nsdfg = first_state.add_nested_sdfg( - nested_sdfg, - parent=sdfg, - inputs={"A", "B"}, - outputs={"C"}, - symbol_mapping={"One": "One", "N": "N"}, - ) - first_state.add_edge( - first_state.add_access("A"), - None, - nsdfg, - "A", - dace.Memlet("A[0:N, 0:N]"), - ) - first_state.add_edge( - first_state.add_access("B"), - None, - nsdfg, - "B", - dace.Memlet("B[0:N, 0:N]"), +def _check_maps( + sdfg: dace.SDFG, + expected_end: str, +) -> bool: + map_entries: 
list[dace_nodes.MapEntry] = util.count_nodes( + graph=sdfg, + node_type=dace_nodes.MapEntry, + return_nodes=True, ) - first_state.add_edge( - nsdfg, - "C", - first_state.add_access("C"), - None, - dace.Memlet("C[0:N, 0:N]"), + return all( + str(map_entry.map.range.ranges[0][1] + 1) == expected_end for map_entry in map_entries ) - second_state: dace.SDFGState = sdfg.add_state_after(first_state) - second_state.add_mapped_tasklet( - "outer_computation", - map_ranges={"__i0": "0:N", "__i1": "0:N"}, - inputs={ - "__in0": dace.Memlet("A[__i0, __i1]"), - "__in1": dace.Memlet("C[__i0, __i1]"), - }, - code="__out = __in0 * __in1", - outputs={"__out": dace.Memlet("D[__i0, __i1]")}, - external_edges=True, - ) - sdfg.validate() - return sdfg, nsdfg +def _check_tasklets( + sdfg: dace.SDFG, + expected_symbols: Optional[set[str]] = None, + forbidden_symbols: Optional[set[str]] = None, +) -> bool: + assert not ((expected_symbols is None) and (forbidden_symbols is None)) + expected_symbols = expected_symbols or set() + forbidden_symbols = forbidden_symbols or set() -def test_constant_substitution(): - sdfg, nsdfg = _make_sdfg_test_constant_sub() - - # Ensure that `One` is present. 
- assert len(sdfg.symbols) == 2 - assert len(nsdfg.sdfg.symbols) == 2 - assert len(nsdfg.symbol_mapping) == 2 - assert "One" in sdfg.symbols - assert "One" in nsdfg.sdfg.symbols - assert "One" in nsdfg.symbol_mapping - assert "One" == str(nsdfg.symbol_mapping["One"]) - assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) - assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) - assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) - assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) - assert "One" in sdfg.used_symbols(True) - - # Now replace `One` with 1 - gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1, "N": 10}) - - assert len(sdfg.symbols) == 1 - assert len(nsdfg.sdfg.symbols) == 1 - assert len(nsdfg.symbol_mapping) == 1 - assert "One" not in sdfg.symbols - assert "One" not in nsdfg.sdfg.symbols - assert "One" not in nsdfg.symbol_mapping - assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) - assert all( - desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() + tasklets: list[dace_nodes.Tasklet] = util.count_nodes( + graph=sdfg, + node_type=dace_nodes.Tasklet, + return_nodes=True, ) - assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) - assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) - assert "One" not in sdfg.used_symbols(True) - + if not all(expected_symbols.issubset(tasklet.free_symbols) for tasklet in tasklets): + return False + if not all(forbidden_symbols.isdisjoint(tasklet.free_symbols) for tasklet in tasklets): + return False + return True -def _make_not_wrapped_sdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("not_wrapped_sdfg")) - state = sdfg.add_state(is_start_block=True) +def make_multi_state_sdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("multi_state_sdfg")) + state = 
sdfg.add_state("stateS", is_start_block=True) sdfg.add_symbol("N", dace.int64) sdfg.add_symbol("lim_area", dace.bool_) - for name in "AB": sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) - state.add_mapped_tasklet( - "PreComp", - map_ranges={"__i": "0:N"}, - inputs={"__in": dace.Memlet("A[__i]")}, - code="__out = __in + N", - outputs={"__out": dace.Memlet("B[__i]")}, - external_edges=True, - ) - - stateT = sdfg.add_state(is_start_block=False) + stateT = sdfg.add_state("stateT", is_start_block=False) stateT.add_mapped_tasklet( "Tcomp", map_ranges={"__i": "0:N"}, inputs={"__in": dace.Memlet("A[__i]")}, - code="__out = __in + N", + code="__out = (__in + 2 * N) if lim_area else (__in - 3 * N)", outputs={"__out": dace.Memlet("B[__i]")}, external_edges=True, ) - stateF = sdfg.add_state(is_start_block=False) + stateF = sdfg.add_state("stateF", is_start_block=False) stateF.add_mapped_tasklet( "Fcomp", map_ranges={"__i": "0:N"}, inputs={"__in": dace.Memlet("A[__i]")}, - code="__out = __in + 2 * N", + code="__out = (__in + 3 * N) if lim_area else (__in - 4 * N)", outputs={"__out": dace.Memlet("B[__i]")}, external_edges=True, ) - stateJ = sdfg.add_state(is_start_block=False) + stateJ = sdfg.add_state("stateJ", is_start_block=False) sdfg.add_edge(state, stateT, dace.InterstateEdge(condition="lim_area")) sdfg.add_edge(state, stateF, dace.InterstateEdge(condition="not lim_area")) sdfg.add_edge(stateT, stateJ, dace.InterstateEdge()) @@ -193,16 +102,72 @@ def _make_not_wrapped_sdfg() -> dace.SDFG: return sdfg -def _make_wrapped_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: +def make_single_state_sdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("single_state_sdfg")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_symbol("N", dace.int64) + sdfg.add_symbol("lim_area", dace.bool_) + for name in "AB": + sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) + + state.add_mapped_tasklet( + "PreComp", + map_ranges={"__i": 
"0:N"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out = (__in + N) if lim_area else (__in - N)", + outputs={"__out": dace.Memlet("B[__i]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def make_single_state_with_two_maps_sdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("single_state_sdfg")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_symbol("N", dace.int64) + sdfg.add_symbol("lim_area", dace.bool_) + for name in "ABT": + sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) + sdfg.arrays["T"].transient = True + + T = state.add_access("T") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:N"}, + inputs={"__in": dace.Memlet("A[__i]")}, + code="__out = (__in + N) if lim_area else (__in - N)", + outputs={"__out": dace.Memlet("T[__i]")}, + output_nodes={T}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:N"}, + inputs={"__in": dace.Memlet("T[__i]")}, + code="__out = (__in + 7 * N) if lim_area else (__in - 4 * N)", + outputs={"__out": dace.Memlet("B[__i]")}, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def make_wrapped_sdfg( + single_state: bool, +) -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: sdfg = dace.SDFG(util.unique_name("wrapped_sdfg")) - state = sdfg.add_state("WRAPPED", is_start_block=True) + state = sdfg.add_state("wrap_state", is_start_block=True) sdfg.add_symbol("lim_area", dace.bool_) - sdfg.add_symbol("N", dace.bool_) + sdfg.add_symbol("N", dace.int64) for name in "AB": sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) + inner_sdfg = make_single_state_sdfg() if single_state else make_multi_state_sdfg() nsdfg = state.add_nested_sdfg( - sdfg=_make_not_wrapped_sdfg(), + sdfg=inner_sdfg, parent=sdfg, inputs={"A"}, outputs={"B"}, @@ -218,84 +183,178 @@ def _make_wrapped_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: return sdfg, nsdfg -def 
test_constant_substitution_not_wrapped_sdfg(): - sdfg: dace.SDFG = _make_not_wrapped_sdfg() - assert sdfg.number_of_nodes() > 1 - assert sdfg.free_symbols == {"N", "lim_area"} - map_entries_old: list[dace_nodes.MapEntry] = util.count_nodes( - sdfg, - node_type=dace_nodes.MapEntry, - return_nodes=True, - ) - assert any(str(map_entry.map.range[0][1] + 1) == "N" for map_entry in map_entries_old) - assert all( - node.desc(sdfg).shape == ("N",) - for node in sdfg.all_nodes_recursive() - if isinstance(node, dace_nodes.AccessNode) - ) +def test_nested_sdfg_with_single_state(): + sdfg, nested_sdfg = make_wrapped_sdfg(single_state=True) + assert _check_shapes(sdfg, ("N",)) + assert _check_shapes(nested_sdfg.sdfg, ("N",)) + assert _check_maps(nested_sdfg.sdfg, "N") + assert _check_tasklets(nested_sdfg.sdfg, expected_symbols={"N", "lim_area"}) + + repl = {"N": 10, "lim_area": True} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) + + assert _check_shapes(sdfg, (10,), to_string=False) + assert _check_shapes(nested_sdfg.sdfg, ("10",)) + assert _check_maps(nested_sdfg.sdfg, "10") + assert _check_tasklets(nested_sdfg.sdfg, forbidden_symbols={"N", "lim_area"}) + assert len(nested_sdfg.symbol_mapping) == 0 + + +def test_nested_sdfg_with_multiple_states(): + sdfg, nested_sdfg = make_wrapped_sdfg(single_state=False) + assert _check_shapes(sdfg, ("N",)) + assert _check_shapes(nested_sdfg.sdfg, ("N",)) + assert _check_maps(nested_sdfg.sdfg, "N") + assert _check_tasklets(nested_sdfg.sdfg, expected_symbols={"N", "lim_area"}) + + repl = {"N": 10, "lim_area": True} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) + + # Due to a bug in DaCe, see `gtx_transformations.gt_substitute_compiletime_symbols()` + # we can not inspect the nested SDFG, since the function has to call simplify. + # For that reason we currently check if the nested SDFG was inlined and the + # whole thing has collapsed to a single state with a map. 
+ # TODO(phimuell): Reactivate after the bug has been fixed. + # assert _check_shapes(nested_sdfg.sdfg, ("10",)) + # assert _check_maps(nested_sdfg.sdfg, "10") + # assert _check_tasklets(nested_sdfg.sdfg, forbidden_symbols={"N", "lim_area"}) + # assert len(nested_sdfg.symbol_mapping) == 0 + # assert _check_shapes(sdfg, (10,), to_string=False) - gtx_transformations.gt_substitute_compiletime_symbols( - sdfg, - repl={"N": 10, "lim_area": True}, - validate=True, - validate_all=True, - ) + assert sdfg.number_of_nodes() == 1 + assert util.count_nodes(sdfg, node_type=dace_nodes.NestedSDFG) == 0 + assert _check_shapes(sdfg, ("10",)) + assert _check_maps(sdfg, "10") + assert _check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"}) + + +def test_single_state_top_sdfg(): + # This test works because everything is inside a single state. + sdfg = make_single_state_sdfg() + assert sdfg.number_of_nodes() == 1 + + assert _check_maps(sdfg, "N") + assert _check_shapes(sdfg, ("N",)) + assert _check_tasklets(sdfg, expected_symbols={"N", "lim_area"}) + + repl = {"N": 10, "lim_area": True} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) + + assert _check_maps(sdfg, "10") + assert _check_shapes(sdfg, (10,), to_string=False) + assert _check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"}) + + +def test_single_state_with_two_maps(): + # This test works because everything is inside a single state. 
+ sdfg = make_single_state_with_two_maps_sdfg() + assert sdfg.number_of_nodes() == 1 + + assert _check_maps(sdfg, "N") + assert _check_shapes(sdfg, ("N",)) + assert _check_tasklets(sdfg, expected_symbols={"N", "lim_area"}) + + repl = {"N": 10, "lim_area": True} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) + + assert _check_maps(sdfg, "10") + assert _check_shapes(sdfg, (10,), to_string=False) + assert _check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"}) + + +def test_multi_state_top_sdfg(): + sdfg = make_multi_state_sdfg() assert sdfg.number_of_nodes() == 4 - assert len(sdfg.free_symbols) == 0 - map_entries: list[dace_nodes.MapEntry] = util.count_nodes( - sdfg, - node_type=dace_nodes.MapEntry, - return_nodes=True, - ) - assert len(map_entries) == 2 - assert any(str(map_entry.map.range[0][1]) == "9" for map_entry in map_entries) - assert all( - node.desc(sdfg).shape == (10,) - for node in sdfg.all_nodes_recursive() - if isinstance(node, dace_nodes.AccessNode) - ) + start_state: dace.SDFGState = sdfg.start_state + assert start_state.label == "stateS" + assert all("lim_area" in edge.data.free_symbols for edge in sdfg.out_edges(start_state)) + + assert _check_maps(sdfg, "N") + assert _check_shapes(sdfg, ("N",)) + assert _check_tasklets(sdfg, expected_symbols={"N", "lim_area"}) + + repl = {"N": 10, "lim_area": True} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) -def test_constant_substitution_wrapped_sdfg(): - sdfg, nsdfg = _make_wrapped_sdfg() + assert _check_maps(sdfg, "10") + assert _check_shapes(sdfg, (10,), to_string=False) + assert _check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"}) + + # Because of the bug in DaCe, see `gtx_transformations.gt_substitute_compiletime_symbols()` + # we can not inspect the condition on the edges, because simplify has been called. + # Thus for the time being we will just test if we are left with one state instead. + # TODO(phimuell): reactivate once the bug has been solved. 
+ # assert not any("lim_area" in edge.data.free_symbols for edge in sdfg.out_edges(start_state)) assert sdfg.number_of_nodes() == 1 - assert sdfg.free_symbols == {"N", "lim_area"} - assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 1 - map_entries_old: list[dace_nodes.MapEntry] = util.count_nodes( - nsdfg.sdfg, - dace_nodes.MapEntry, - return_nodes=True, + +def test_single_state_nested_with_top_map(): + sdfg, nested_sdfg = make_wrapped_sdfg(single_state=True) + assert sdfg.number_of_nodes() == 1 + state: dace.SDFGState = list(sdfg.states())[0] + + sdfg.add_datadesc("new_input", sdfg.arrays["A"].clone()) + sdfg.arrays["A"].transient = True + A: dace_nodes.AccessNode = next( + iter(dnode for dnode in state.data_nodes() if dnode.data == "A") ) - assert any(str(map_entry.map.range[0][1] + 1) == "N" for map_entry in map_entries_old) - assert all( - node.desc(nsdfg.sdfg).shape == ("N",) - for node in sdfg.all_nodes_recursive() - if isinstance(node, dace_nodes.AccessNode) + state.add_mapped_tasklet( + "compOutside", + map_ranges={"__i": "0:N"}, + inputs={"__in": dace.Memlet("new_input[__i]")}, + code="__out = (__in + 10 * N) if lim_area else (__in - 14 * N)", + outputs={"__out": dace.Memlet("A[__i]")}, + output_nodes={A}, + external_edges=True, ) + sdfg.validate() - sdfg.view() - gtx_transformations.gt_substitute_compiletime_symbols( - sdfg, - repl={"N": "10", "lim_area": 1}, - validate=True, - validate_all=True, - ) + assert _check_maps(sdfg, "N") + assert _check_shapes(sdfg, ("N",)) + assert _check_tasklets(sdfg, expected_symbols={"N", "lim_area"}) + assert _check_shapes(nested_sdfg.sdfg, ("N",)) + assert _check_maps(nested_sdfg.sdfg, "N") + assert _check_tasklets(nested_sdfg.sdfg, expected_symbols={"N", "lim_area"}) - assert sdfg.number_of_nodes() == 1 - assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 1 - assert len(sdfg.free_symbols) == 0 - for nsdfg in sdfg.all_sdfgs_recursive(): - assert all( - node.desc(nsdfg).shape == (10,) - for node in 
sdfg.all_nodes_recursive() - if isinstance(node, dace_nodes.AccessNode) - ) + repl = {"N": 10, "lim_area": True} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) - map_entries: list[dace_nodes.MapEntry] = util.count_nodes( - nsdfg.sdfg, - dace_nodes.MapEntry, - return_nodes=True, + assert _check_maps(sdfg, "10") + assert _check_shapes(sdfg, (10,), to_string=False) + assert _check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"}) + assert _check_shapes(nested_sdfg.sdfg, ("10",)) + assert _check_maps(nested_sdfg.sdfg, "10") + assert _check_tasklets(nested_sdfg.sdfg, forbidden_symbols={"N", "lim_area"}) + assert len(nested_sdfg.symbol_mapping) == 0 + + +@pytest.mark.xfail(reason="AccessNode replacement can not be done yet.") +def test_replace_access_node(): + sdfg = dace.SDFG(util.unique_name("replaced_access_node")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_symbol("N", dace.int64) + for name in "AB": + sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False) + sdfg.add_scalar("S", dtype=dace.float64, transient=False) + + tsklt, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:N"}, + inputs={ + "__in1": dace.Memlet("A[__i0]"), + "__in2": dace.Memlet("S[0]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("B[__i0]")}, + external_edges=True, ) - assert len(map_entries) == 2 - assert any(str(map_entry.map.range[0][1]) == "9" for map_entry in map_entries) + sdfg.validate() + + repl = {"N": 10, "S": 10} + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, repl) + + assert len(list(dnode for dnode in state.data_nodes() if dnode.data == "S")) == 0 + assert _check_maps(sdfg, "10") + assert _check_shapes(sdfg, (10,), to_string=False) + assert _check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"})