Skip to content

Commit

Permalink
Fix incorrect offsets during assignments with broadcasts (spcl#1875)
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Jan 22, 2025
1 parent f566e9a commit 855fc27
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
13 changes: 11 additions & 2 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2725,8 +2725,17 @@ def _add_assignment(self,

fake_subset = dace.subsets.Range(missing_dimensions + op_dimensions)

# use this fake subset to calculate the offset
fake_subset.offset(squeezed, True)
# Use this fake subset to calculate the offset. Constant indices are ignored, as they do not depend
# on the broadcasting operation.
offset_indices_to_ignore = set()
for i, idx in enumerate(inp_idx):
if not symbolic.issymbolic(pystr_to_symbolic(idx)):
offset_indices_to_ignore.add(i)
fake_subset_offs_indices = []
for i in range(len(fake_subset)):
if i not in offset_indices_to_ignore:
fake_subset_offs_indices.append(i)
fake_subset.offset(squeezed, True, indices=fake_subset_offs_indices)

# we access the inp subset using the computed offset
# since the inp_subset may be missing leading dimensions, we reverse-zip-reverse
Expand Down
10 changes: 6 additions & 4 deletions tests/numpy/split_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def test_dsplit_4d():
return a, b, c


def test_compiletime_split():
@pytest.mark.parametrize('out_idx', [0, 1])
def test_compiletime_split(out_idx):

@dace.program
def tester(x, y, in_indices: dace.compiletime, out_index: dace.compiletime):
Expand All @@ -138,9 +139,9 @@ def tester(x, y, in_indices: dace.compiletime, out_index: dace.compiletime):

x = np.random.rand(1000, 8)
y = np.zeros_like(x)
tester(x, y, (1, 2, 3, 4, 5, 7), 0)
tester(x, y, (1, 2, 3, 4, 5, 7), out_idx)
ref = np.zeros_like(y)
tester.f(x, ref, (1, 2, 3, 4, 5, 7), 0)
tester.f(x, ref, (1, 2, 3, 4, 5, 7), out_idx)

assert np.allclose(y, ref)

Expand All @@ -158,4 +159,5 @@ def tester(x, y, in_indices: dace.compiletime, out_index: dace.compiletime):
test_vsplit()
test_hsplit()
test_dsplit_4d()
test_compiletime_split()
test_compiletime_split(0)
test_compiletime_split(1)

0 comments on commit 855fc27

Please sign in to comment.