diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 50bf5e9fb2..24b38274ae 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -587,7 +587,7 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): # Descend recursively to remove scalar remove_scalar_reads(dst.sdfg, {e.dst_conn: tmp_symname}) - for ise in dst.sdfg.edges(): + for ise in dst.sdfg.all_interstate_edges(): ise.data.replace(e.dst_conn, tmp_symname) # Remove subscript occurrences as well for aname, aval in ise.data.assignments.items(): @@ -595,6 +595,12 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): vast = astutils.RemoveSubscripts({tmp_symname}).visit(vast) ise.data.assignments[aname] = astutils.unparse(vast) ise.data.replace(tmp_symname + '[0]', tmp_symname) + promo = TaskletPromoterDict({e.dst_conn: tmp_symname}) + for reg in dst.sdfg.all_control_flow_regions(): + meta_codes = reg.get_meta_codeblocks() + for cd in meta_codes: + for stmt in cd.code: + promo.visit(stmt) # Set symbol mapping dst.sdfg.remove_data(e.dst_conn, validate=False)