diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 30640306cd..47c180aff6 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2620,7 +2620,7 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: """ return [] - def replace_meta_accesses(self, replacements: dict) -> None: + def replace_meta_accesses(self, replacements: Dict[str, str]) -> None: """ Replace accesses to specific data containers in reads or writes performed by the control flow region itself in meta accesses, such as in condition checks for conditional blocks or in loop conditions for loops, etc. @@ -3331,6 +3331,8 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: return read_memlets def replace_meta_accesses(self, replacements): + if self.loop_variable in replacements: + self.loop_variable = replacements[self.loop_variable] replace_in_codeblock(self.loop_condition, replacements) if self.init_statement: replace_in_codeblock(self.init_statement, replacements) diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index a01d903a1d..c501a769ff 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -111,6 +111,8 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: if node.code_exit.language != dtypes.Language.Python: result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), node.ignored_symbols) + else: + result |= block.used_symbols(all_symbols=True, with_contents=False) for e in sdfg.all_interstate_edges(): result |= e.data.free_symbols