diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 4a141aef12..1bd343ecfb 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -102,6 +102,9 @@ def _nested_arrays_from_json(obj, context=None): def _replace_dict_keys(d, old, new): + if old == new: + warnings.warn(f"Trying to replace key with the same name {old} ... skipping.") + return if old in d: if new in d: warnings.warn('"%s" already exists in SDFG' % new) @@ -734,6 +737,12 @@ def replace_dict(self, :param replace_in_graph: Whether to replace in SDFG nodes / edges. :param replace_keys: If True, replaces in SDFG property names (e.g., array, symbol, and constant names). """ + + repldict = {k: v for k, v in repldict.items() if k != v} + if symrepl: + symrepl = {k: v for k, v in symrepl.items() if str(k) != str(v)} + + symrepl = symrepl or { symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v for k, v in repldict.items() diff --git a/tests/sdfg/interstate_assignment_test.py b/tests/sdfg/interstate_assignment_test.py new file mode 100644 index 0000000000..31efe3f63c --- /dev/null +++ b/tests/sdfg/interstate_assignment_test.py @@ -0,0 +1,36 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +import numpy as np + + +def test_key_replacement_same_name(): + + sdfg = dace.SDFG('key_replacement_same_name') + sdfg.add_array('inp', [1], dace.int32) + sdfg.add_array('out', [1], dace.int32) + + first = sdfg.add_state('first_state') + second = sdfg.add_state('second_state') + edge = sdfg.add_edge(first, second, dace.InterstateEdge(assignments={'s': 'inp[0]'})) + + task = second.add_tasklet('t', {}, {'__out'}, '__out = s') + access = second.add_access('out') + second.add_edge(task, '__out', access, None, dace.Memlet('out[0]')) + + sdfg.replace('s', 's') + assert 's' in edge.data.assignments + sdfg.replace_dict({'s': 's'}) + assert 's' in edge.data.assignments + + rng = np.random.default_rng() + inp = rng.integers(1, 100, 1) + inp = np.array(inp, dtype=np.int32) + out = np.zeros([1], dtype=np.int32) + + sdfg(inp=inp, out=out) + assert out[0] == inp[0] + + +if __name__ == '__main__': + test_key_replacement_same_name()