Skip to content

Commit

Permalink
Fix two bugs in state fusion and prune empty conditional branches (sp…
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Dec 18, 2024
1 parent a517699 commit 3281661
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion dace/transformation/interstate/state_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,15 @@ def apply(self, _, sdfg):

# Merge common (data) nodes
merged_nodes = set()
removed_nodes = set()
for node in second_mid:

# merge only top level nodes, skip everything else
if node not in top2:
continue

candidates = [x for x in order if x.data == node.data and x in top and x not in merged_nodes]
candidates = [x for x in order
if x.data == node.data and x in top and x not in merged_nodes and x not in removed_nodes]
source_node = first_state.in_degree(node) == 0

# If not source node, try to connect every memlet-intersecting candidate
Expand All @@ -552,6 +554,7 @@ def apply(self, _, sdfg):
sdutil.change_edge_src(first_state, cand, node)
sdutil.change_edge_dest(first_state, cand, node)
first_state.remove_node(cand)
removed_nodes.add(cand)
continue

if len(candidates) == 0:
Expand All @@ -571,6 +574,7 @@ def apply(self, _, sdfg):
sdutil.change_edge_src(first_state, node, n)
sdutil.change_edge_dest(first_state, node, n)
first_state.remove_node(node)
removed_nodes.add(node)
merged_nodes.add(n)

# Redirect edges and remove second state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]:
if len(region.branches) == 0:
# The conditional has become entirely empty, remove it.
replacement_node_before = region.parent_graph.add_state_before(region)
replacement_node_after = region.parent_graph.add_state_before(region)
replacement_node_after = region.parent_graph.add_state_after(region)
region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge())
region.parent_graph.remove_node(region)

Expand Down

0 comments on commit 3281661

Please sign in to comment.