Skip to content

Commit

Permalink
Fixes replacement of dict keys (spcl#1845)
Browse files Browse the repository at this point in the history
If the internal method `SDFG._replace_dict_keys` is called with the same
`old` and `new` keys, then the method erroneously removes the entry for
the specific key completely. Such a case occurs when calling
`SDFG.replace_dict` with a dictionary that contains identity replacement
matches.

This PR addresses the issue by:
- [x] Adding a check inside `SDFG._replace_dict_keys` and skipping if
`old == new`. A warning is also thrown for debugging purposes.
- [x] Removing identity replacement matches from `repldict` and
`symrepl` in `SDFG.replace_dict`.
  • Loading branch information
alexnick83 authored Jan 7, 2025
1 parent 43883ea commit 6254c0f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
9 changes: 9 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def _nested_arrays_from_json(obj, context=None):


def _replace_dict_keys(d, old, new):
if old == new:
warnings.warn(f"Trying to replace key with the same name {old} ... skipping.")
return
if old in d:
if new in d:
warnings.warn('"%s" already exists in SDFG' % new)
Expand Down Expand Up @@ -734,6 +737,12 @@ def replace_dict(self,
:param replace_in_graph: Whether to replace in SDFG nodes / edges.
:param replace_keys: If True, replaces in SDFG property names (e.g., array, symbol, and constant names).
"""

repldict = {k: v for k, v in repldict.items() if k != v}
if symrepl:
symrepl = {k: v for k, v in symrepl.items() if str(k) != str(v)}


symrepl = symrepl or {
symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v
for k, v in repldict.items()
Expand Down
36 changes: 36 additions & 0 deletions tests/sdfg/interstate_assignment_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.

import dace
import numpy as np


def test_key_replacement_same_name():

sdfg = dace.SDFG('key_replacement_same_name')
sdfg.add_array('inp', [1], dace.int32)
sdfg.add_array('out', [1], dace.int32)

first = sdfg.add_state('first_state')
second = sdfg.add_state('second_state')
edge = sdfg.add_edge(first, second, dace.InterstateEdge(assignments={'s': 'inp[0]'}))

task = second.add_tasklet('t', {}, {'__out'}, '__out = s')
access = second.add_access('out')
second.add_edge(task, '__out', access, None, dace.Memlet('out[0]'))

sdfg.replace('s', 's')
assert 's' in edge.data.assignments
sdfg.replace_dict({'s': 's'})
assert 's' in edge.data.assignments

rng = np.random.default_rng()
inp = rng.integers(1, 100, 1)
inp = np.array(inp, dtype=np.int32)
out = np.zeros([1], dtype=np.int32)

sdfg(inp=inp, out=out)
assert out[0] == inp[0]


if __name__ == '__main__':
test_key_replacement_same_name()

0 comments on commit 6254c0f

Please sign in to comment.