From 6254c0fc0cb09f2e8249b0aef04ef81ad7375aa3 Mon Sep 17 00:00:00 2001 From: alexnick83 <31545860+alexnick83@users.noreply.github.com> Date: Tue, 7 Jan 2025 14:47:31 +0100 Subject: [PATCH] Fixes replacement of dict keys (#1845) If the internal method `SDFG._replace_dict_keys` is called with the same `old` and `new` keys, then the method erroneously removes the entry for the specific key completely. Such a case occurs when calling `SDFG.replace_dict` with a dictionary that contains identity replacement matches. This PR addresses the issue by: - [x] Adding a check inside `SDFG._replace_dict_keys` and skipping if `old == new`. A warning is also thrown for debugging purposes. - [x] Removing identity replacement matches from `repldict` and `symrepl` in `SDFG.replace_dict`. --- dace/sdfg/sdfg.py | 9 ++++++ tests/sdfg/interstate_assignment_test.py | 36 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 tests/sdfg/interstate_assignment_test.py diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 4a141aef12..1bd343ecfb 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -102,6 +102,9 @@ def _nested_arrays_from_json(obj, context=None): def _replace_dict_keys(d, old, new): + if old == new: + warnings.warn(f"Trying to replace key with the same name {old} ... skipping.") + return if old in d: if new in d: warnings.warn('"%s" already exists in SDFG' % new) @@ -734,6 +737,12 @@ def replace_dict(self, :param replace_in_graph: Whether to replace in SDFG nodes / edges. :param replace_keys: If True, replaces in SDFG property names (e.g., array, symbol, and constant names). """ + + repldict = {k: v for k, v in repldict.items() if k != v} + if symrepl: + symrepl = {k: v for k, v in symrepl.items() if str(k) != str(v)} + + symrepl = symrepl or { symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v for k, v in repldict.items() diff --git a/tests/sdfg/interstate_assignment_test.py b/tests/sdfg/interstate_assignment_test.py new file mode 100644 index 0000000000..31efe3f63c --- /dev/null +++ b/tests/sdfg/interstate_assignment_test.py @@ -0,0 +1,36 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +import numpy as np + + +def test_key_replacement_same_name(): + + sdfg = dace.SDFG('key_replacement_same_name') + sdfg.add_array('inp', [1], dace.int32) + sdfg.add_array('out', [1], dace.int32) + + first = sdfg.add_state('first_state') + second = sdfg.add_state('second_state') + edge = sdfg.add_edge(first, second, dace.InterstateEdge(assignments={'s': 'inp[0]'})) + + task = second.add_tasklet('t', {}, {'__out'}, '__out = s') + access = second.add_access('out') + second.add_edge(task, '__out', access, None, dace.Memlet('out[0]')) + + sdfg.replace('s', 's') + assert 's' in edge.data.assignments + sdfg.replace_dict({'s': 's'}) + assert 's' in edge.data.assignments + + rng = np.random.default_rng() + inp = rng.integers(1, 100, 1) + inp = np.array(inp, dtype=np.int32) + out = np.zeros([1], dtype=np.int32) + + sdfg(inp=inp, out=out) + assert out[0] == inp[0] + + +if __name__ == '__main__': + test_key_replacement_same_name()