From f566e9adacd88e1aaf41f07fce81aa6529b2cce2 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 22 Jan 2025 09:51:28 +0100 Subject: [PATCH] Fix scalar2symbol not being correctly adapted to CFRs (#1889) --- dace/transformation/passes/scalar_to_symbol.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 50bf5e9fb2..24b38274ae 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -587,7 +587,7 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): # Descend recursively to remove scalar remove_scalar_reads(dst.sdfg, {e.dst_conn: tmp_symname}) - for ise in dst.sdfg.edges(): + for ise in dst.sdfg.all_interstate_edges(): ise.data.replace(e.dst_conn, tmp_symname) # Remove subscript occurrences as well for aname, aval in ise.data.assignments.items(): @@ -595,6 +595,12 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): vast = astutils.RemoveSubscripts({tmp_symname}).visit(vast) ise.data.assignments[aname] = astutils.unparse(vast) ise.data.replace(tmp_symname + '[0]', tmp_symname) + promo = TaskletPromoterDict({e.dst_conn: tmp_symname}) + for reg in dst.sdfg.all_control_flow_regions(): + meta_codes = reg.get_meta_codeblocks() + for cd in meta_codes: + for stmt in cd.code: + promo.visit(stmt) # Set symbol mapping dst.sdfg.remove_data(e.dst_conn, validate=False)