Skip to content

Commit

Permalink
Implement numpy.any/all, more descriptive Python frontend errors (#1836)
Browse files Browse the repository at this point in the history
* Implement `numpy.any` and `numpy.all`
* More descriptive errors when slicing and using non-annotated callbacks
* Remove `dace.compiletime` arguments from symbolic analysis
  • Loading branch information
tbennun authored Jan 6, 2025
1 parent b36142b commit 7dc7957
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 11 deletions.
18 changes: 12 additions & 6 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4441,8 +4441,8 @@ def parse_target(t: Union[ast.Name, ast.Subscript]):
# Connect Python state
self._connect_pystate(tasklet, self.current_state, '__istate', '__ostate')

if return_type is None:
return []
if return_type is None: # Unknown but potentially used return value
return [dtypes.pyobject()]
else:
return return_names

Expand Down Expand Up @@ -4996,13 +4996,13 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS
operand2, op2type = None, None

# Type-check operands in order to provide a clear error message
if (isinstance(operand1, str) and operand1 in self.defined
and isinstance(self.defined[operand1].dtype, dtypes.pyobject)):
if (isinstance(operand1, dtypes.pyobject) or (isinstance(operand1, str) and operand1 in self.defined
and isinstance(self.defined[operand1].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op1, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand1}" to enable using it within the program.')
if (isinstance(operand2, str) and operand2 in self.defined
and isinstance(self.defined[operand2].dtype, dtypes.pyobject)):
if (isinstance(operand2, dtypes.pyobject) or (isinstance(operand2, str) and operand2 in self.defined
and isinstance(self.defined[operand2].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op2, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand2}" to enable using it within the program.')
Expand Down Expand Up @@ -5289,6 +5289,12 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False):
# Try to construct memlet from subscript
node.value = ast.Name(id=array)
defined = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.defined})

if arrtype is data.Scalar and array in defined and isinstance(defined[array].dtype, dtypes.pyobject):
raise DaceSyntaxError(
self, node, f'Object "{array}" is defined as a callback return value and cannot be sliced. '
'Consider adding a type hint to the variable.')

expr: MemletExpr = ParseMemlet(self, defined, node, nslice)

if inference:
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any]
# Start with default arguments, then add other arguments
result = {**self.default_args}
# Reconstruct keyword arguments
result.update({aname: arg for aname, arg in zip(self.argnames, args)})
result.update({aname: arg for aname, arg in zip(self.argnames, args) if aname not in self.constant_args})
result.update(kwargs)

# Add closure arguments to the call
Expand Down
20 changes: 16 additions & 4 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,16 @@ def _sum_array(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a: str):
return _reduce(pv, sdfg, state, "lambda x, y: x + y", a, axis=0, identity=0)


@oprepo.replaces('numpy.any')
def _any(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
return _reduce(pv, sdfg, state, "lambda x, y: x or y", a, axis=axis, identity=0)


@oprepo.replaces('numpy.all')
def _all(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
return _reduce(pv, sdfg, state, "lambda x, y: x and y", a, axis=axis, identity=0)


@oprepo.replaces('numpy.mean')
def _mean(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):

Expand All @@ -1042,26 +1052,28 @@ def _mean(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):

@oprepo.replaces('numpy.max')
@oprepo.replaces('numpy.amax')
def _max(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
def _max(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None, initial=None):
initial = initial if initial is not None else dtypes.min_value(sdfg.arrays[a].dtype)
return _reduce(pv,
sdfg,
state,
"lambda x, y: max(x, y)",
a,
axis=axis,
identity=dtypes.min_value(sdfg.arrays[a].dtype))
identity=initial)


@oprepo.replaces('numpy.min')
@oprepo.replaces('numpy.amin')
def _min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
def _min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None, initial=None):
initial = initial if initial is not None else dtypes.max_value(sdfg.arrays[a].dtype)
return _reduce(pv,
sdfg,
state,
"lambda x, y: min(x, y)",
a,
axis=axis,
identity=dtypes.max_value(sdfg.arrays[a].dtype))
identity=initial)


@oprepo.replaces('numpy.clip')
Expand Down
13 changes: 13 additions & 0 deletions tests/numpy/reductions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,16 @@ def test_degenerate_reduction_implicit(A: dace.float64[1, 20]):
return np.sum(A, axis=0)


@compare_numpy_output()
def test_any(A: dace.float64[20]):
return np.any(A > 0.8, axis=0)


@compare_numpy_output()
def test_all(A: dace.float64[20]):
return np.all(A > 0.8, axis=0)


if __name__ == '__main__':

# generated with cat tests/numpy/reductions_test.py | grep -oP '(?<=^def ).*(?=\()' | awk '{print $0 "()"}'
Expand Down Expand Up @@ -271,3 +281,6 @@ def test_degenerate_reduction_implicit(A: dace.float64[1, 20]):
test_scalar_reduction()
test_degenerate_reduction_explicit()
test_degenerate_reduction_implicit()

test_any()
test_all()
35 changes: 35 additions & 0 deletions tests/python_frontend/callback_autodetect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import time
from dace import config
from dace.frontend.python.common import DaceSyntaxError

N = dace.symbol('N')

Expand Down Expand Up @@ -906,6 +907,38 @@ def tester(a: dace.float64[20]):
assert np.allclose(aa, expected)


def test_disallowed_callback_in_condition():

@dace_inhibitor
def callbackfunc(arr):
return 42

@dace.program
def callback_in_condition(arr: dace.float64[20]):
if arr[0] < callbackfunc(arr):
return arr + 1
else:
return arr

with pytest.raises(DaceSyntaxError, match="Trying to operate on a callback"):
callback_in_condition.to_sdfg()


def test_disallowed_callback_slice():

@dace_inhibitor
def callbackfunc(arr):
return 42

@dace.program
def callback_in_condition(arr: dace.float64[20]):
a = callbackfunc(arr)
return arr + a[:20]

with pytest.raises(DaceSyntaxError, match="cannot be sliced"):
callback_in_condition.to_sdfg()


@pytest.mark.skip('Test requires GUI')
def test_matplotlib_with_compute():
"""
Expand Down Expand Up @@ -978,4 +1011,6 @@ def tester():
test_pyobject_return_tuple()
test_custom_generator()
test_custom_generator_with_break()
test_disallowed_callback_in_condition()
test_disallowed_callback_slice()
# test_matplotlib_with_compute()

0 comments on commit 7dc7957

Please sign in to comment.