Skip to content

Commit

Permalink
Python: Fix scalar struct attribute writes in frontend (spcl#1847)
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Jan 10, 2025
1 parent c13a362 commit 3ad9f82
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
12 changes: 6 additions & 6 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3328,14 +3328,15 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
visited_target = True
else:
true_name = defined_vars[name]
while len(tokens) > 1:
while len(tokens):
true_name = true_name + '.' + tokens.pop(0)
if true_name not in self.sdfg.arrays:
if true_name not in defined_arrays:
break
if tokens: # The non-struct remainder will be considered an attribute
attribute_name = '.'.join(tokens)
raise DaceSyntaxError(
self, target, f'Cannot assign to attribute "{attribute_name}" of variable "{true_name}"')
self, target,
f'Cannot assign to attribute "{attribute_name}" of variable "{true_name}"')

true_array = defined_arrays[true_name]

Expand Down Expand Up @@ -3515,7 +3516,6 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if boolarr is not None and indirect_indices:
raise IndexError('Boolean array indexing cannot be combined with indirect access')


if self.nested and not new_data and not visited_target:
new_name, new_rng = self._add_write_access(name, rng, target)
# Local symbol or local data dependent
Expand Down Expand Up @@ -5032,12 +5032,12 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS

# Type-check operands in order to provide a clear error message
if (isinstance(operand1, dtypes.pyobject) or (isinstance(operand1, str) and operand1 in self.defined
and isinstance(self.defined[operand1].dtype, dtypes.pyobject))):
and isinstance(self.defined[operand1].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op1, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand1}" to enable using it within the program.')
if (isinstance(operand2, dtypes.pyobject) or (isinstance(operand2, str) and operand2 in self.defined
and isinstance(self.defined[operand2].dtype, dtypes.pyobject))):
and isinstance(self.defined[operand2].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op2, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand2}" to enable using it within the program.')
Expand Down
6 changes: 6 additions & 0 deletions dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,15 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):

# Add NestedSDFG arrays
for name in read_set | write_set:
if '.' in name:
root_data_name = name.split('.')[0]
name = root_data_name
nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name])
nsdfg.arrays[name].transient = False
for name in unique_set:
if '.' in name:
root_data_name = name.split('.')[0]
name = root_data_name
nsdfg.arrays[name] = sdfg.arrays[name]
del sdfg.arrays[name]

Expand Down
35 changes: 35 additions & 0 deletions tests/python_frontend/structures/structure_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,40 @@ def dense_to_csr_python(A: dace.float32[M, N], B: CSR):
func(A=A, B=outB, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz)


def test_write_structure_scalar():

N = dace.symbol('N')
SumStruct = dace.data.Structure(dict(sum=dace.data.Scalar(dace.float64)), name='SumStruct')

@dace.program
def struct_member_based_sum(A: dace.float64[N], B: SumStruct, C: dace.float64[N]):
tmp = 0.0
for i in range(N):
tmp += A[i]
B.sum = tmp
for i in range(N):
C[i] = A[i] + B.sum

N = 40
A = np.random.rand(N)
C = np.random.rand(N)
C_val = np.zeros((N,))
sum = 0
for i in range(N):
sum += A[i]
for i in range(N):
C_val[i] = A[i] + sum

outB = SumStruct.dtype._typeclass.as_ctypes()(sum=0)

func = struct_member_based_sum.compile()
func(A=A, B=outB, C=C, N=N)

# C is used for numerical validation because the Python frontend does not allow directly writing to scalars as an
# output (B.sum). Using them as intermediate values is possible though.
assert np.allclose(C, C_val)


def test_local_structure():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
Expand Down Expand Up @@ -227,6 +261,7 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]):
if __name__ == '__main__':
test_read_structure()
test_write_structure()
test_write_structure_scalar()
test_local_structure()
test_rgf()
# test_read_structure_gpu()

0 comments on commit 3ad9f82

Please sign in to comment.