From bf547145b9e65d8fb5afb3abbc5dc1cd948c6034 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 12 Feb 2025 17:09:18 +0100 Subject: [PATCH] Fix DSE pass for conditional block with single branch and unconditional execution (#1930) This PR addresses #1928 by fixing some code in the `DeadStateElimination` (DSE) pass. This code was supposed to handle exactly the case reported in the issue, but the check for a single branch with unconditional execution was not entirely correct. A test case is added. Includes a change in the FPGA tests to address a CI error not related to this change: https://github.com/spcl/dace/actions/runs/13242670532/job/36961472662?pr=1930 --- .../passes/dead_state_elimination.py | 23 ++++---- tests/fpga/vec_sum_test.py | 3 +- tests/passes/dead_code_elimination_test.py | 55 +++++++++++++++++++ 3 files changed, 68 insertions(+), 13 deletions(-) diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 80ecaa49fb..24b38b7e52 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -65,16 +65,17 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters for _, b in dead_branches: result.add(b) node.remove_branch(b) - # If only an 'else' is left over, inline it. - if len(node.branches) == 1 and node.branches[0][0] is None: - branch = node.branches[0][1] - node.parent_graph.add_node(branch) - for ie in cfg.in_edges(node): - cfg.add_edge(ie.src, branch, ie.data) - for oe in cfg.out_edges(node): - cfg.add_edge(branch, oe.dst, oe.data) - result.add(node) - cfg.remove_node(node) + # If only one branch is left, and it is unconditionally executed, inline it. 
+ if len(node.branches) == 1: + cond, branch = node.branches[0] + if cond is None or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), sdfg): + node.parent_graph.add_node(branch) + for ie in cfg.in_edges(node): + cfg.add_edge(ie.src, branch, ie.data) + for oe in cfg.out_edges(node): + cfg.add_edge(branch, oe.dst, oe.data) + result.add(node) + cfg.remove_node(node) else: result.add(node) is_start = node is cfg.start_block @@ -170,7 +171,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch') break # If an unconditional branch is found, ignore all other branches that follow this one. - if cond.as_string.strip() == '1' or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + if self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): unconditional = branch break if unconditional is not None: diff --git a/tests/fpga/vec_sum_test.py b/tests/fpga/vec_sum_test.py index 791ba80e5d..2ce1a5fe97 100644 --- a/tests/fpga/vec_sum_test.py +++ b/tests/fpga/vec_sum_test.py @@ -34,8 +34,7 @@ def sum(i: _[0:N]): X = rng.random(n, dtype=np.float32) Y = rng.random(n, dtype=np.float32) Z = rng.random(n, dtype=np.float32) - ref = np.empty(n, dtype=np.float32) - ref[:] = X + Y + Z + ref = X + Y + Z sdfg = vec_sum.to_sdfg() diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index 231ccac84f..5014fe8073 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -380,6 +380,60 @@ def test_dce_add_type_hint_of_variable(dtype): assert np.all(out == np.where(cond, true_value, false_value)) +def test_prune_single_branch_conditional_block(): + sdfg = dace.SDFG("conditional_sdfg") + + for name in "abc": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + 
sdfg.arrays["b"].transient = True + + first_state = sdfg.add_state("first_state") + first_state.add_mapped_tasklet( + "first_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("b[__i0]")}, + external_edges=True, + ) + + # create states inside the nested SDFG for the if-branches + if_region = dace.sdfg.state.ConditionalBlock("if") + sdfg.add_node(if_region) + sdfg.add_edge( + first_state, + if_region, + dace.InterstateEdge() + ) + + then_body = dace.sdfg.state.ControlFlowRegion( + "then_body", + sdfg=sdfg + ) + then_state = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch( + dace.sdfg.state.CodeBlock("True"), + then_body + ) + then_state.add_mapped_tasklet( + "second_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("b[__i0]")}, + code="__out = __in1 + 1.0", + outputs={"__out": dace.Memlet("c[__i0]")}, + external_edges=True, + ) + sdfg.validate() + res = DeadStateElimination().apply_pass(sdfg, {}) + assert res and len(res) == 1 + assert sdfg.out_edges(first_state)[0].dst == then_body + + if __name__ == '__main__': test_dse_simple() test_dse_unconditional() @@ -400,3 +454,4 @@ def test_dce_add_type_hint_of_variable(dtype): test_dce_add_type_hint_of_variable(dace.float64) test_dce_add_type_hint_of_variable(dace.bool) test_dce_add_type_hint_of_variable(np.float64) + test_prune_single_branch_conditional_block()