diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 7e3dc6916b..0ce5b9f437 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -534,13 +534,15 @@ def apply(self, _, sdfg): # Merge common (data) nodes merged_nodes = set() + removed_nodes = set() for node in second_mid: # merge only top level nodes, skip everything else if node not in top2: continue - candidates = [x for x in order if x.data == node.data and x in top and x not in merged_nodes] + candidates = [x for x in order + if x.data == node.data and x in top and x not in merged_nodes and x not in removed_nodes] source_node = first_state.in_degree(node) == 0 # If not source node, try to connect every memlet-intersecting candidate @@ -552,6 +554,7 @@ def apply(self, _, sdfg): sdutil.change_edge_src(first_state, cand, node) sdutil.change_edge_dest(first_state, cand, node) first_state.remove_node(cand) + removed_nodes.add(cand) continue if len(candidates) == 0: @@ -571,6 +574,7 @@ def apply(self, _, sdfg): sdutil.change_edge_src(first_state, node, n) sdutil.change_edge_dest(first_state, node, n) first_state.remove_node(node) + removed_nodes.add(node) merged_nodes.add(n) # Redirect edges and remove second state diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py index d7bd397830..a492a9a65c 100644 --- a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -61,7 +61,7 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]: if len(region.branches) == 0: # The conditional has become entirely empty, remove it. replacement_node_before = region.parent_graph.add_state_before(region) - replacement_node_after = region.parent_graph.add_state_before(region) + replacement_node_after = region.parent_graph.add_state_after(region) region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge()) region.parent_graph.remove_node(region)