Skip to content

Commit

Permalink
Fix DSE pass for conditional block with single branch and uncondition…
Browse files Browse the repository at this point in the history
…al execution (spcl#1930)

This PR addresses spcl#1928 by fixing some code in `DeadStateElimination`
(DSE) pass. This code was supposed to handle exactly the case reported
in the issue, but the check for single branch with unconditional
execution was not entirely correct. Test case added.


Includes a change in fpga tests to address a CI error not related to
this change:

https://github.com/spcl/dace/actions/runs/13242670532/job/36961472662?pr=1930
  • Loading branch information
edopao authored Feb 12, 2025
1 parent 5097d6f commit bf54714
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 13 deletions.
23 changes: 12 additions & 11 deletions dace/transformation/passes/dead_state_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,17 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters
for _, b in dead_branches:
result.add(b)
node.remove_branch(b)
# If only an 'else' is left over, inline it.
if len(node.branches) == 1 and node.branches[0][0] is None:
branch = node.branches[0][1]
node.parent_graph.add_node(branch)
for ie in cfg.in_edges(node):
cfg.add_edge(ie.src, branch, ie.data)
for oe in cfg.out_edges(node):
cfg.add_edge(branch, oe.dst, oe.data)
result.add(node)
cfg.remove_node(node)
# If only one branch is left, and it is unconditionally executed, inline it.
if len(node.branches) == 1:
cond, branch = node.branches[0]
if cond is None or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), sdfg):
node.parent_graph.add_node(branch)
for ie in cfg.in_edges(node):
cfg.add_edge(ie.src, branch, ie.data)
for oe in cfg.out_edges(node):
cfg.add_edge(branch, oe.dst, oe.data)
result.add(node)
cfg.remove_node(node)
else:
result.add(node)
is_start = node is cfg.start_block
Expand Down Expand Up @@ -170,7 +171,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock,
raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch')
break
# If an unconditional branch is found, ignore all other branches that follow this one.
if cond.as_string.strip() == '1' or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg):
if self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg):
unconditional = branch
break
if unconditional is not None:
Expand Down
3 changes: 1 addition & 2 deletions tests/fpga/vec_sum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def sum(i: _[0:N]):
X = rng.random(n, dtype=np.float32)
Y = rng.random(n, dtype=np.float32)
Z = rng.random(n, dtype=np.float32)
ref = np.empty(n, dtype=np.float32)
ref[:] = X + Y + Z
ref = X + Y + Z

sdfg = vec_sum.to_sdfg()

Expand Down
55 changes: 55 additions & 0 deletions tests/passes/dead_code_elimination_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,60 @@ def test_dce_add_type_hint_of_variable(dtype):
assert np.all(out == np.where(cond, true_value, false_value))


def test_prune_single_branch_conditional_block():
sdfg = dace.SDFG("conditional_sdfg")

for name in "abc":
sdfg.add_array(
name,
shape=(10,),
dtype=dace.float64,
transient=False,
)
sdfg.arrays["b"].transient = True

first_state = sdfg.add_state("first_state")
first_state.add_mapped_tasklet(
"first_comp",
map_ranges={"__i0": "0:10"},
inputs={"__in1": dace.Memlet("a[__i0]")},
code="__out = __in1 + 10.0",
outputs={"__out": dace.Memlet("b[__i0]")},
external_edges=True,
)

# create states inside the nested SDFG for the if-branches
if_region = dace.sdfg.state.ConditionalBlock("if")
sdfg.add_node(if_region)
sdfg.add_edge(
first_state,
if_region,
dace.InterstateEdge()
)

then_body = dace.sdfg.state.ControlFlowRegion(
"then_body",
sdfg=sdfg
)
then_state = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(
dace.sdfg.state.CodeBlock("True"),
then_body
)
then_state.add_mapped_tasklet(
"second_comp",
map_ranges={"__i0": "0:10"},
inputs={"__in1": dace.Memlet("b[__i0]")},
code="__out = __in1 + 1.0",
outputs={"__out": dace.Memlet("c[__i0]")},
external_edges=True,
)
sdfg.validate()
res = DeadStateElimination().apply_pass(sdfg, {})
assert res and len(res) == 1
assert sdfg.out_edges(first_state)[0].dst == then_body


if __name__ == '__main__':
test_dse_simple()
test_dse_unconditional()
Expand All @@ -400,3 +454,4 @@ def test_dce_add_type_hint_of_variable(dtype):
test_dce_add_type_hint_of_variable(dace.float64)
test_dce_add_type_hint_of_variable(dace.bool)
test_dce_add_type_hint_of_variable(np.float64)
test_prune_single_branch_conditional_block()

0 comments on commit bf54714

Please sign in to comment.