Skip to content

Commit

Permalink
Adds support for NumPy version 2 (#1613)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 authored Feb 6, 2025
1 parent 118c131 commit 5097d6f
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 17 deletions.
2 changes: 1 addition & 1 deletion dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def _Repr(self, t):

def _Num(self, t):
t_n = t.value if sys.version_info >= (3, 8) else t.n
repr_n = repr(t_n)
repr_n = str(t_n)
# For complex values, use ``dtype_to_typeclass``
if isinstance(t_n, complex):
dtype = dtypes.dtype_to_typeclass(complex)
Expand Down
8 changes: 8 additions & 0 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def subscript_to_ast_slice_recursive(node):


class ExtUnparser(astunparse.Unparser):

def _Constant(self, t):
# NOTE: This is needed since NumPy 2.0 to avoid unparsing NumPy scalars as calls, e.g. `numpy.int32(1)`
if isinstance(t.value, numbers.Number):
self.write(str(t.value))
else:
super()._Constant(t)

def _Subscript(self, t):
self.dispatch(t.value)
self.write('[')
Expand Down
6 changes: 4 additions & 2 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import numpy
import sympy

numpy_version = numpy.lib.NumpyVersion(numpy.__version__)

# register replacements in oprepo
import dace.frontend.python.replacements
from dace.frontend.python.replacements import _sym_type, broadcast_to, broadcast_together
Expand Down Expand Up @@ -4918,9 +4920,9 @@ def visit_Num(self, node: NumConstant):
return node.n

def visit_Constant(self, node: ast.Constant):
if isinstance(node.value, bool):
if isinstance(node.value, bool) and numpy_version < '2.0.0':
return dace.bool_(node.value)
if isinstance(node.value, (int, float, complex)):
if isinstance(node.value, (int, float, complex)) and numpy_version < '2.0.0':
return dtypes.dtype_to_typeclass(type(node.value))(node.value)
if isinstance(node.value, (str, bytes)):
return StringLiteral(node.value)
Expand Down
20 changes: 17 additions & 3 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import numpy as np
import sympy as sp

numpy_version = np.lib.NumpyVersion(np.__version__)

Size = Union[int, dace.symbolic.symbol]
Shape = Sequence[Size]
if TYPE_CHECKING:
Expand Down Expand Up @@ -1680,22 +1682,28 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi

datatypes = []
dtypes_for_result = []
dtypes_for_result_np2 = []
for arg in arguments:
if isinstance(arg, (data.Array, data.Stream)):
datatypes.append(arg.dtype)
dtypes_for_result.append(arg.dtype.type)
dtypes_for_result_np2.append(arg.dtype.type)
elif isinstance(arg, data.Scalar):
datatypes.append(arg.dtype)
dtypes_for_result.append(_representative_num(arg.dtype))
dtypes_for_result_np2.append(arg.dtype.type)
elif isinstance(arg, (Number, np.bool_)):
datatypes.append(dtypes.dtype_to_typeclass(type(arg)))
dtypes_for_result.append(arg)
dtypes_for_result_np2.append(arg)
elif symbolic.issymbolic(arg):
datatypes.append(_sym_type(arg))
dtypes_for_result.append(_representative_num(_sym_type(arg)))
dtypes_for_result_np2.append(_sym_type(arg).type)
elif isinstance(arg, dtypes.typeclass):
datatypes.append(arg)
dtypes_for_result.append(_representative_num(arg))
dtypes_for_result_np2.append(arg.type)
else:
raise TypeError("Type {t} of argument {a} is not supported".format(t=type(arg), a=arg))

Expand Down Expand Up @@ -1728,8 +1736,11 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi
elif (operator in ('Fabs', 'Cbrt', 'Angles', 'SignBit', 'Spacing', 'Modf', 'Floor', 'Ceil', 'Trunc')
and coarse_types[0] == 3):
raise TypeError("ufunc '{}' not supported for complex input".format(operator))
elif operator in ('Ceil', 'Floor', 'Trunc') and coarse_types[0] < 2 and numpy_version < '2.1.0':
result_type = dace.float64
casting[0] = _cast_str(result_type)
elif (operator in ('Fabs', 'Rint', 'Exp', 'Log', 'Sqrt', 'Cbrt', 'Trigonometric', 'Angles', 'FpBoolean',
'Spacing', 'Modf', 'Floor', 'Ceil', 'Trunc') and coarse_types[0] < 2):
'Spacing', 'Modf') and coarse_types[0] < 2):
result_type = dace.float64
casting[0] = _cast_str(result_type)
elif operator in ('Frexp'):
Expand Down Expand Up @@ -1809,7 +1820,10 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi
result_type = dace.float64
# All other arithmetic operators and cases of the above operators
else:
result_type = _np_result_type(dtypes_for_result)
if numpy_version >= '2.0.0':
result_type = _np_result_type(dtypes_for_result_np2)
else:
result_type = _np_result_type(dtypes_for_result)

if dtype1 != result_type:
left_cast = _cast_str(result_type)
Expand Down Expand Up @@ -2701,7 +2715,7 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
operator=None,
inputs=["__in1"],
outputs=["__out"],
code="__out = sign(__in1)",
code="__out = sign_numpy_2(__in1)" if numpy_version >= '2.0.0' else "__out = sign(__in1)",
reduce=None,
initial=np.sign.identity),
heaviside=dict(name="_numpy_heaviside_",
Expand Down
10 changes: 10 additions & 0 deletions dace/runtime/include/dace/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ template<typename T>
static DACE_CONSTEXPR DACE_HDFI std::complex<T> sign(const std::complex<T>& x) {
return (x.real() != 0) ? std::complex<T>(sign(x.real()), 0) : std::complex<T>(sign(x.imag()), 0);
}
// Numpy v2.0 or higher for complex inputs: sign(x) = x / abs(x)
template<typename T>
static DACE_CONSTEXPR DACE_HDFI T sign_numpy_2(const T& x) {
return T( (T(0) < x) - (x < T(0)) );
// return (x < 0) ? -1 : ( (x > 0) ? 1 : 0);
}
template<typename T>
static DACE_CONSTEXPR DACE_HDFI std::complex<T> sign_numpy_2(const std::complex<T>& x) {
return (x.real() != 0 && x.imag() != 0) ? x / std::abs(x) : std::complex<T>(0, 0);
}

// Computes the Heaviside step function
template<typename T>
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
},
include_package_data=True,
install_requires=[
'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply',
'numpy', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply',
'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"', 'packaging'
] + cmake_requires,
Expand Down
18 changes: 9 additions & 9 deletions tests/numpy/ufunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ def test_ufunc_isfinite_c():
@compare_numpy_output(check_dtype=True)
def ufunc_isfinite_c(A: dace.complex64[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isfinite(A)

args = dace.Config.get('compiler', 'cpu', 'args')
Expand All @@ -997,7 +997,7 @@ def test_ufunc_isfinite_f():
@compare_numpy_output(check_dtype=True)
def ufunc_isfinite_f(A: dace.float32[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isfinite(A)

args = dace.Config.get('compiler', 'cpu', 'args')
Expand All @@ -1017,7 +1017,7 @@ def ufunc_isfinite_f(A: dace.float32[10]):
@compare_numpy_output(validation_func=lambda a: np.isfinite(a))
def test_ufunc_isfinite_u(A: dace.uint32[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isfinite(A)


Expand All @@ -1026,7 +1026,7 @@ def test_ufunc_isinf_c():
@compare_numpy_output(check_dtype=True)
def ufunc_isinf_c(A: dace.complex64[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isinf(A)

args = dace.Config.get('compiler', 'cpu', 'args')
Expand All @@ -1045,7 +1045,7 @@ def test_ufunc_isinf_f():
@compare_numpy_output(check_dtype=True)
def ufunc_isinf_f(A: dace.float32[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isinf(A)

args = dace.Config.get('compiler', 'cpu', 'args')
Expand All @@ -1065,7 +1065,7 @@ def ufunc_isinf_f(A: dace.float32[10]):
@compare_numpy_output(validation_func=lambda a: np.isinf(a))
def test_ufunc_isinf_u(A: dace.uint32[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isinf(A)


Expand All @@ -1074,7 +1074,7 @@ def test_ufunc_isnan_c():
@compare_numpy_output(check_dtype=True)
def ufunc_isnan_c(A: dace.complex64[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isnan(A)

args = dace.Config.get('compiler', 'cpu', 'args')
Expand All @@ -1093,7 +1093,7 @@ def test_ufunc_isnan_f():
@compare_numpy_output(check_dtype=True)
def ufunc_isnan_f(A: dace.float32[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isnan(A)

args = dace.Config.get('compiler', 'cpu', 'args')
Expand All @@ -1113,7 +1113,7 @@ def ufunc_isnan_f(A: dace.float32[10]):
@compare_numpy_output(validation_func=lambda a: np.isnan(a))
def test_ufunc_isnan_u(A: dace.uint32[10]):
A[0] = np.inf
A[1] = np.NaN
A[1] = np.nan
return np.isnan(A)


Expand Down
5 changes: 4 additions & 1 deletion tests/numpy/unop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def test_not():
B = np.zeros((5, 5), dtype=np.int64).astype(np.bool_)
regression = np.logical_not(A)
nottest(A, B)
assert np.alltrue(B == regression)
if np.lib.NumpyVersion(np.__version__) >= '2.0.0':
assert np.all(B == regression)
else:
assert np.alltrue(B == regression)


if __name__ == '__main__':
Expand Down

0 comments on commit 5097d6f

Please sign in to comment.