Skip to content

Commit

Permalink
Fix read_and_write_sets() for ControlFlowRegion (spcl#1920)
Browse files Browse the repository at this point in the history
The `PruneConnectors` transformation relies on `read_and_write_sets()`
to identify connectors on a nested SDFG that refer to unused data
containers. With the introduction of `ControlFlowRegion` nodes, the data
containers can now be accessed by symbolic expressions used as
conditions in such nodes. This case was not considered in baseline.

---------

Co-authored-by: Philip Mueller <philip.mueller@cscs.ch>
  • Loading branch information
edopao and philip-paul-mueller authored Feb 4, 2025
1 parent 8c24a34 commit 118c131
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
12 changes: 10 additions & 2 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,12 +1376,20 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
read_set = set()
write_set = set()
for state in self.states():
for edge in state.parent_graph.in_edges(state):
read_set |= edge.data.free_symbols & self.arrays.keys()
# Get dictionaries of subsets read and written from each state
rs, ws = state._read_and_write_sets()
read_set |= rs.keys()
write_set |= ws.keys()

array_names = self.arrays.keys()
for edge in self.all_interstate_edges():
read_set |= edge.data.free_symbols & array_names

# By definition, data that is referenced by symbolic condition expressions
# (branching condition, loop condition, ...) is also part of the read set.
for cfr in self.all_control_flow_regions():
read_set |= cfr.used_symbols(all_symbols=True, with_contents=False) & array_names

return read_set, write_set

def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]:
Expand Down
44 changes: 44 additions & 0 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,49 @@ def test_read_write():
assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)


def test_prune_connectors_with_conditional_block():
"""
Verifies that a connector to scalar data (here 'cond') in a NestedSDFG is not removed,
when this data is only accessed by condition expressions in ControlFlowRegion nodes.
"""
sdfg = dace.SDFG('tester')
A, A_desc = sdfg.add_array('A', [4], dace.float64)
B, B_desc = sdfg.add_array('B', [4], dace.float64)
COND, COND_desc = sdfg.add_array('COND', [4], dace.bool_)
OUT, OUT_desc = sdfg.add_array('OUT', [4], dace.float64)

nsdfg = dace.SDFG('nested')
a, _ = nsdfg.add_scalar('a', A_desc.dtype)
b, _ = nsdfg.add_scalar('b', B_desc.dtype)
cond, _ = nsdfg.add_scalar('cond', COND_desc.dtype)
out, _ = nsdfg.add_scalar('out', OUT_desc.dtype)

if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region)
entry_state = nsdfg.add_state("entry", is_start_block=True)
nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg)
a_state = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock(cond), then_body)
a_state.add_nedge(a_state.add_access(a), a_state.add_access(out), dace.Memlet(out))

else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg)
b_state = else_body.add_state("false_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock(f"not ({cond})"), else_body)
b_state.add_nedge(b_state.add_access(b), b_state.add_access(out), dace.Memlet(out))

state = sdfg.add_state()
nsdfg_node = state.add_nested_sdfg(nsdfg, sdfg, inputs={a, b, cond}, outputs={out})
me, mx = state.add_map('map', dict(i="0:4"))
state.add_memlet_path(state.add_access(A), me, nsdfg_node, dst_conn=a, memlet=dace.Memlet(f"{A}[i]"))
state.add_memlet_path(state.add_access(B), me, nsdfg_node, dst_conn=b, memlet=dace.Memlet(f"{B}[i]"))
state.add_memlet_path(state.add_access(COND), me, nsdfg_node, dst_conn=cond, memlet=dace.Memlet(f"{COND}[i]"))
state.add_memlet_path(nsdfg_node, mx, state.add_access(OUT), src_conn=out, memlet=dace.Memlet(f"{OUT}[i]"))

assert 0 == sdfg.apply_transformations_repeated(PruneConnectors)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--N", default=64)
Expand All @@ -431,3 +474,4 @@ def test_read_write():
test_prune_connectors_with_dependencies()
test_read_write_1()
test_read_write_2()
test_prune_connectors_with_conditional_block()

0 comments on commit 118c131

Please sign in to comment.