diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index 96f2086b5e..f6278fb944 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -247,7 +247,7 @@ def constant_field( def _compose_function_field_with_builtin(builtin_name: str) -> Callable: def _composed_function_field(field: FunctionField) -> FunctionField: - if builtin_name not in _BUILTINS: + if builtin_name not in _UNARY_BUILTINS: raise ValueError(f"Unsupported built-in function: {builtin_name}") if builtin_name in ["abs", "power", "gamma"]: @@ -261,13 +261,14 @@ def _composed_function_field(field: FunctionField) -> FunctionField: FunctionField.register_builtin_func(fbuiltins.broadcast, _broadcast) -_BUILTINS = fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES + fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES +_UNARY_BUILTINS = fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES + fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES -for builtin_name in _BUILTINS: - if builtin_name in ["abs", "power", "gamma"]: +for builtin_name in _UNARY_BUILTINS: + if builtin_name in ["abs", "gamma"]: continue FunctionField.register_builtin_func(getattr(fbuiltins, builtin_name), _compose_function_field_with_builtin(builtin_name)) +FunctionField.register_builtin_func(fbuiltins.abs, FunctionField.__abs__) def _get_params(func: Callable) -> str: """Pretty print callable parameters.""" diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index bfee55249f..8e93dff846 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -259,6 +259,16 @@ def _slice( getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name) ) +_BaseNdArrayField.register_builtin_func( + fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined] +) +_BaseNdArrayField.register_builtin_func( + fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined] +) +_BaseNdArrayField.register_builtin_func( + fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] +) + def _np_cp_setitem( self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py index ea6f87f900..0fec4d14ee 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import math import operator import numpy as np @@ -20,63 +19,23 @@ from gt4py.next import common from gt4py.next.common import Dimension, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf -from gt4py.next import fbuiltins - from .test_common import get_infinite_domain, get_mixed_domain - - +from .test_nd_array_field import binary_logical_op, binary_arithmetic_op, binary_reverse_arithmetic_op I = Dimension("I") J = Dimension("J") K = Dimension("K") -def rfloordiv(x, y): - return operator.floordiv(y, x) - - -operators = [ - operator.add, - operator.sub, - operator.mul, - operator.truediv, - operator.floordiv, - operator.pow, - lambda x, y: operator.truediv(y, x), # Reverse true division - lambda x, y: operator.add(y, x), # Reverse addition - lambda x, y: operator.mul(y, x), # Reverse multiplication - lambda x, y: operator.sub(y, x), # Reverse subtraction - lambda x, y: operator.floordiv(y, x), # Reverse floor division -] - -logical_operators = [ - operator.xor, - operator.and_, - operator.or_, -] - - -@pytest.mark.parametrize( - "op_func, expected_result", - [ - (operator.add, 10 + 20), - (operator.sub, 10 - 20), - (operator.mul, 10 * 20), - (operator.truediv, 10 / 20), - (operator.floordiv, 10 // 20), - (rfloordiv, 20 // 10), - (operator.pow, 10**20), - (lambda x, y: operator.truediv(y, x), 20 / 10), - (operator.add, 10 + 20), - (operator.mul, 10 * 20), - (lambda x, y: operator.sub(y, x), 20 - 10), - ], -) -def test_constant_field_no_domain(op_func, expected_result): +def test_constant_field_no_domain(binary_arithmetic_op, binary_reverse_arithmetic_op): cf1 = funcf.constant_field(10) cf2 = funcf.constant_field(20) - result = op_func(cf1, cf2) - assert result.func() == expected_result + + ops = [binary_arithmetic_op, binary_reverse_arithmetic_op] + + for op in ops: + result = op(cf1, cf2) + assert result.func() == op(10, 20) @pytest.fixture( @@ -161,11 +120,7 @@ def adder(i, j): return i + j -@pytest.mark.parametrize( - "op_func", - operators, -) -def test_function_field_broadcast(op_func): +def test_function_field_broadcast(binary_arithmetic_op, binary_reverse_arithmetic_op): func1 = lambda x, y: x + y func2 = lambda y: 2 * y @@ -175,14 +130,16 @@ def test_function_field_broadcast(op_func): ff1 = funcf.FunctionField(func1, domain1) ff2 = funcf.FunctionField(func2, domain2) - result = op_func(ff1, ff2) + ops = [binary_arithmetic_op, binary_reverse_arithmetic_op] - assert result.func(5, 10) == op_func(func1(5, 10), func2(10)) - assert isinstance(result.ndarray, np.ndarray) + for op in ops: + result = op(ff1, ff2) + assert result.func(5, 10) == op(func1(5, 10), func2(10)) + assert isinstance(result.ndarray, np.ndarray) -@pytest.mark.parametrize("op_func", logical_operators) -def test_function_field_logical_operators(op_func): + +def test_function_field_logical_operators(binary_logical_op): func1 = lambda x, y: x > 5 func2 = lambda y: y < 10 @@ -192,9 +149,9 @@ def test_function_field_logical_operators(op_func): ff1 = funcf.FunctionField(func1, domain1) ff2 = funcf.FunctionField(func2, domain2) - result = op_func(ff1, ff2) + result = binary_logical_op(ff1, ff2) - assert result.func(5, 10) == op_func(func1(5, 10), func2(10)) + assert result.func(5, 10) == binary_logical_op(func1(5, 10), func2(10)) assert isinstance(result.ndarray, np.ndarray) @@ -296,21 +253,6 @@ def test_function_field_infinite_range(get_infinite_domain, get_mixed_domain): ff.ndarray -@pytest.mark.parametrize("builtin_name", funcf._BUILTINS) -def test_function_field_builtins(function_field, builtin_name): - if builtin_name in ["abs", "power", "gamma"]: - pytest.skip(f"Skipping '{builtin_name}'") - - fbuiltin_func = getattr(fbuiltins, builtin_name) - - result = fbuiltin_func(function_field).func(1, 2) - - if math.isnan(result): - assert math.isnan(np.__getattribute__(builtin_name)(3)) - else: - assert result == np.__getattribute__(builtin_name)(3) - - def test_unary_logical_op_boolean(): boolean_func = lambda x: x % 2 != 0 field = funcf.FunctionField(boolean_func, common.Domain((I, UnitRange(1, 10)))) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index e18a2683e6..e50d66b41e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import itertools -import math import operator -from typing import Callable, Iterable import numpy as np import pytest @@ -22,13 +20,11 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field +from gt4py.next.embedded import function_field as funcf from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins from tests.next_tests.unit_tests.test_common import IDim, JDim -from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data - - IDim = Dimension("IDim") JDim = Dimension("JDim") KDim = Dimension("KDim") @@ -69,47 +65,79 @@ def unary_arithmetic_op(request): def unary_logical_op(request): yield request.param +@pytest.fixture( + params=[ + lambda x, y: operator.truediv(y, x), # Reverse true division + lambda x, y: operator.add(y, x), # Reverse addition + lambda x, y: operator.mul(y, x), # Reverse multiplication + lambda x, y: operator.sub(y, x), # Reverse subtraction + lambda x, y: operator.floordiv(y, x), # Reverse floor division + ] +) +def binary_reverse_arithmetic_op(request): + yield request.param -def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): + +def _make_base_ndarray_field(arr: np.ndarray, nd_array_implementation, *, dtype=None): if not dtype: dtype = nd_array_implementation.float32 return common.field( - nd_array_implementation.asarray(lst, dtype=dtype), - domain={common.Dimension("foo"): (0, len(lst))}, + nd_array_implementation.asarray(arr, dtype=dtype), + domain={common.Dimension("foo"): (0, len(arr))}, + ) + + +def _make_function_field(): + def adder(i, j): + return i + j + + return funcf.FunctionField( + adder, + domain=common.Domain( + dims=(IDim, JDim), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + ), ) +normal_dist = np.random.normal(3, 2.5, size=(10,)) -@pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) -def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementation): +@pytest.fixture( + params=[ + _make_base_ndarray_field(normal_dist, np), + _make_function_field() + ] +) +def all_field_types(request): + yield request.param + + +@pytest.mark.parametrize("builtin_name", funcf._UNARY_BUILTINS) +def test_unary_builtins_for_all_fields(all_field_types, builtin_name): if builtin_name == "gamma": # numpy has no gamma function pytest.xfail("TODO: implement gamma") - ref_impl: Callable = np.vectorize(math.gamma) - else: - ref_impl: Callable = getattr(np, builtin_name) - - expected = ref_impl(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) - - field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] - builtin = getattr(fbuiltins, builtin_name) - result = builtin(*field_inputs) + fbuiltin_func = getattr(fbuiltins, builtin_name) + result = fbuiltin_func(all_field_types).ndarray - assert np.allclose(result.ndarray, expected) + expected = getattr(np, builtin_name)(all_field_types.ndarray) + assert np.allclose(result, expected, equal_nan=True) -def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): +def test_binary_arithmetic_ops(binary_arithmetic_op, binary_reverse_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] inputs = [inp_a, inp_b] - expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) + ops = [binary_arithmetic_op, binary_reverse_arithmetic_op] - field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] + for op in ops: + expected = op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) - result = binary_arithmetic_op(*field_inputs) + field_inputs = [_make_base_ndarray_field(inp, nd_array_implementation) for inp in inputs] - assert np.allclose(result.ndarray, expected) + result = op(*field_inputs) + + assert np.allclose(result.ndarray, expected) def test_binary_logical_ops(binary_logical_op, nd_array_implementation): @@ -119,7 +147,7 @@ def test_binary_logical_ops(binary_logical_op, nd_array_implementation): expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) - field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + field_inputs = [_make_base_ndarray_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] result = binary_logical_op(*field_inputs) @@ -134,7 +162,7 @@ def test_unary_logical_ops(unary_logical_op, nd_array_implementation): expected = unary_logical_op(np.asarray(inp)) - field_input = _make_field(inp, nd_array_implementation, dtype=bool) + field_input = _make_base_ndarray_field(inp, nd_array_implementation, dtype=bool) result = unary_logical_op(field_input) @@ -146,7 +174,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) - field_input = _make_field(inp, nd_array_implementation) + field_input = _make_base_ndarray_field(inp, nd_array_implementation) result = unary_arithmetic_op(field_input) @@ -197,8 +225,8 @@ def test_mixed_fields(product_nd_array_implementation): expected = np.asarray(inp_a) + np.asarray(inp_b) - field_inp_a = _make_field(inp_a, first_impl) - field_inp_b = _make_field(inp_b, second_impl) + field_inp_a = _make_base_ndarray_field(inp_a, first_impl) + field_inp_b = _make_base_ndarray_field(inp_b, second_impl) result = field_inp_a + field_inp_b assert np.allclose(result.ndarray, expected) @@ -215,9 +243,9 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: expected = np.asarray(inp_a) * np.asarray(inp_b) + np.asarray(inp_c) - field_inp_a = _make_field(inp_a, np) - field_inp_b = _make_field(inp_b, np) - field_inp_c = _make_field(inp_c, np) + field_inp_a = _make_base_ndarray_field(inp_a, np) + field_inp_b = _make_base_ndarray_field(inp_b, np) + field_inp_c = _make_base_ndarray_field(inp_c, np) result = fma(field_inp_a, field_inp_b, field_inp_c) assert np.allclose(result.ndarray, expected)