Skip to content

Commit

Permalink
Introduce builtin test for all fields
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 13, 2023
1 parent f248c69 commit 2a18ccc
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 113 deletions.
9 changes: 5 additions & 4 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def constant_field(

def _compose_function_field_with_builtin(builtin_name: str) -> Callable:
def _composed_function_field(field: FunctionField) -> FunctionField:
if builtin_name not in _BUILTINS:
if builtin_name not in _UNARY_BUILTINS:
raise ValueError(f"Unsupported built-in function: {builtin_name}")

if builtin_name in ["abs", "power", "gamma"]:
Expand All @@ -261,13 +261,14 @@ def _composed_function_field(field: FunctionField) -> FunctionField:

FunctionField.register_builtin_func(fbuiltins.broadcast, _broadcast)

_BUILTINS = fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES + fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES
_UNARY_BUILTINS = fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES + fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES

for builtin_name in _BUILTINS:
if builtin_name in ["abs", "power", "gamma"]:
for builtin_name in _UNARY_BUILTINS:
if builtin_name in ["abs", "gamma"]:
continue
FunctionField.register_builtin_func(getattr(fbuiltins, builtin_name), _compose_function_field_with_builtin(builtin_name))

FunctionField.register_builtin_func(fbuiltins.abs, FunctionField.__abs__)

def _get_params(func: Callable) -> str:
"""Pretty print callable parameters."""
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,16 @@ def _slice(
getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name)
)

_BaseNdArrayField.register_builtin_func(
fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined]
)
_BaseNdArrayField.register_builtin_func(
fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined]
)
_BaseNdArrayField.register_builtin_func(
fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined]
)


def _np_cp_setitem(
self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT],
Expand Down
94 changes: 18 additions & 76 deletions tests/next_tests/unit_tests/embedded_tests/test_function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import math
import operator

import numpy as np
Expand All @@ -20,63 +19,23 @@
from gt4py.next import common
from gt4py.next.common import Dimension, UnitRange
from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf
from gt4py.next import fbuiltins

from .test_common import get_infinite_domain, get_mixed_domain


from .test_nd_array_field import binary_logical_op, binary_arithmetic_op, binary_reverse_arithmetic_op

I = Dimension("I")
J = Dimension("J")
K = Dimension("K")


def rfloordiv(x, y):
return operator.floordiv(y, x)


operators = [
operator.add,
operator.sub,
operator.mul,
operator.truediv,
operator.floordiv,
operator.pow,
lambda x, y: operator.truediv(y, x), # Reverse true division
lambda x, y: operator.add(y, x), # Reverse addition
lambda x, y: operator.mul(y, x), # Reverse multiplication
lambda x, y: operator.sub(y, x), # Reverse subtraction
lambda x, y: operator.floordiv(y, x), # Reverse floor division
]

logical_operators = [
operator.xor,
operator.and_,
operator.or_,
]


@pytest.mark.parametrize(
"op_func, expected_result",
[
(operator.add, 10 + 20),
(operator.sub, 10 - 20),
(operator.mul, 10 * 20),
(operator.truediv, 10 / 20),
(operator.floordiv, 10 // 20),
(rfloordiv, 20 // 10),
(operator.pow, 10**20),
(lambda x, y: operator.truediv(y, x), 20 / 10),
(operator.add, 10 + 20),
(operator.mul, 10 * 20),
(lambda x, y: operator.sub(y, x), 20 - 10),
],
)
def test_constant_field_no_domain(op_func, expected_result):
def test_constant_field_no_domain(binary_arithmetic_op, binary_reverse_arithmetic_op):
cf1 = funcf.constant_field(10)
cf2 = funcf.constant_field(20)
result = op_func(cf1, cf2)
assert result.func() == expected_result

ops = [binary_arithmetic_op, binary_reverse_arithmetic_op]

for op in ops:
result = op(cf1, cf2)
assert result.func() == op(10, 20)


@pytest.fixture(
Expand Down Expand Up @@ -161,11 +120,7 @@ def adder(i, j):
return i + j


@pytest.mark.parametrize(
"op_func",
operators,
)
def test_function_field_broadcast(op_func):
def test_function_field_broadcast(binary_arithmetic_op, binary_reverse_arithmetic_op):
func1 = lambda x, y: x + y
func2 = lambda y: 2 * y

Expand All @@ -175,14 +130,16 @@ def test_function_field_broadcast(op_func):
ff1 = funcf.FunctionField(func1, domain1)
ff2 = funcf.FunctionField(func2, domain2)

result = op_func(ff1, ff2)
ops = [binary_arithmetic_op, binary_reverse_arithmetic_op]

assert result.func(5, 10) == op_func(func1(5, 10), func2(10))
assert isinstance(result.ndarray, np.ndarray)
for op in ops:
result = op(ff1, ff2)

assert result.func(5, 10) == op(func1(5, 10), func2(10))
assert isinstance(result.ndarray, np.ndarray)

@pytest.mark.parametrize("op_func", logical_operators)
def test_function_field_logical_operators(op_func):

def test_function_field_logical_operators(binary_logical_op):
func1 = lambda x, y: x > 5
func2 = lambda y: y < 10

Expand All @@ -192,9 +149,9 @@ def test_function_field_logical_operators(op_func):
ff1 = funcf.FunctionField(func1, domain1)
ff2 = funcf.FunctionField(func2, domain2)

result = op_func(ff1, ff2)
result = binary_logical_op(ff1, ff2)

assert result.func(5, 10) == op_func(func1(5, 10), func2(10))
assert result.func(5, 10) == binary_logical_op(func1(5, 10), func2(10))
assert isinstance(result.ndarray, np.ndarray)


Expand Down Expand Up @@ -296,21 +253,6 @@ def test_function_field_infinite_range(get_infinite_domain, get_mixed_domain):
ff.ndarray


@pytest.mark.parametrize("builtin_name", funcf._BUILTINS)
def test_function_field_builtins(function_field, builtin_name):
if builtin_name in ["abs", "power", "gamma"]:
pytest.skip(f"Skipping '{builtin_name}'")

fbuiltin_func = getattr(fbuiltins, builtin_name)

result = fbuiltin_func(function_field).func(1, 2)

if math.isnan(result):
assert math.isnan(np.__getattribute__(builtin_name)(3))
else:
assert result == np.__getattribute__(builtin_name)(3)


def test_unary_logical_op_boolean():
boolean_func = lambda x: x % 2 != 0
field = funcf.FunctionField(boolean_func, common.Domain((I, UnitRange(1, 10))))
Expand Down
94 changes: 61 additions & 33 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,19 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import itertools
import math
import operator
from typing import Callable, Iterable

import numpy as np
import pytest

from gt4py.next import Dimension, common
from gt4py.next.common import Domain, UnitRange
from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field
from gt4py.next.embedded import function_field as funcf
from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice
from gt4py.next.ffront import fbuiltins
from tests.next_tests.unit_tests.test_common import IDim, JDim

from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data


IDim = Dimension("IDim")
JDim = Dimension("JDim")
KDim = Dimension("KDim")
Expand Down Expand Up @@ -69,47 +65,79 @@ def unary_arithmetic_op(request):
def unary_logical_op(request):
yield request.param

@pytest.fixture(
params=[
lambda x, y: operator.truediv(y, x), # Reverse true division
lambda x, y: operator.add(y, x), # Reverse addition
lambda x, y: operator.mul(y, x), # Reverse multiplication
lambda x, y: operator.sub(y, x), # Reverse subtraction
lambda x, y: operator.floordiv(y, x), # Reverse floor division
]
)
def binary_reverse_arithmetic_op(request):
yield request.param

def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None):

def _make_base_ndarray_field(arr: np.ndarray, nd_array_implementation, *, dtype=None):
if not dtype:
dtype = nd_array_implementation.float32
return common.field(
nd_array_implementation.asarray(lst, dtype=dtype),
domain={common.Dimension("foo"): (0, len(lst))},
nd_array_implementation.asarray(arr, dtype=dtype),
domain={common.Dimension("foo"): (0, len(arr))},
)


def _make_function_field():
def adder(i, j):
return i + j

return funcf.FunctionField(
adder,
domain=common.Domain(
dims=(IDim, JDim), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
),
)

normal_dist = np.random.normal(3, 2.5, size=(10,))

@pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data())
def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementation):
@pytest.fixture(
params=[
_make_base_ndarray_field(normal_dist, np),
_make_function_field()
]
)
def all_field_types(request):
yield request.param


@pytest.mark.parametrize("builtin_name", funcf._UNARY_BUILTINS)
def test_unary_builtins_for_all_fields(all_field_types, builtin_name):
if builtin_name == "gamma":
# numpy has no gamma function
pytest.xfail("TODO: implement gamma")
ref_impl: Callable = np.vectorize(math.gamma)
else:
ref_impl: Callable = getattr(np, builtin_name)

expected = ref_impl(*[np.asarray(inp, dtype=np.float32) for inp in inputs])

field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs]

builtin = getattr(fbuiltins, builtin_name)
result = builtin(*field_inputs)
fbuiltin_func = getattr(fbuiltins, builtin_name)
result = fbuiltin_func(all_field_types).ndarray

assert np.allclose(result.ndarray, expected)
expected = getattr(np, builtin_name)(all_field_types.ndarray)
assert np.allclose(result, expected, equal_nan=True)


def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation):
def test_binary_arithmetic_ops(binary_arithmetic_op, binary_reverse_arithmetic_op, nd_array_implementation):
inp_a = [-1.0, 4.2, 42]
inp_b = [2.0, 3.0, -3.0]
inputs = [inp_a, inp_b]

expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs])
ops = [binary_arithmetic_op, binary_reverse_arithmetic_op]

field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs]
for op in ops:
expected = op(*[np.asarray(inp, dtype=np.float32) for inp in inputs])

result = binary_arithmetic_op(*field_inputs)
field_inputs = [_make_base_ndarray_field(inp, nd_array_implementation) for inp in inputs]

assert np.allclose(result.ndarray, expected)
result = op(*field_inputs)

assert np.allclose(result.ndarray, expected)


def test_binary_logical_ops(binary_logical_op, nd_array_implementation):
Expand All @@ -119,7 +147,7 @@ def test_binary_logical_ops(binary_logical_op, nd_array_implementation):

expected = binary_logical_op(*[np.asarray(inp) for inp in inputs])

field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs]
field_inputs = [_make_base_ndarray_field(inp, nd_array_implementation, dtype=bool) for inp in inputs]

result = binary_logical_op(*field_inputs)

Expand All @@ -134,7 +162,7 @@ def test_unary_logical_ops(unary_logical_op, nd_array_implementation):

expected = unary_logical_op(np.asarray(inp))

field_input = _make_field(inp, nd_array_implementation, dtype=bool)
field_input = _make_base_ndarray_field(inp, nd_array_implementation, dtype=bool)

result = unary_logical_op(field_input)

Expand All @@ -146,7 +174,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation):

expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32))

field_input = _make_field(inp, nd_array_implementation)
field_input = _make_base_ndarray_field(inp, nd_array_implementation)

result = unary_arithmetic_op(field_input)

Expand Down Expand Up @@ -197,8 +225,8 @@ def test_mixed_fields(product_nd_array_implementation):

expected = np.asarray(inp_a) + np.asarray(inp_b)

field_inp_a = _make_field(inp_a, first_impl)
field_inp_b = _make_field(inp_b, second_impl)
field_inp_a = _make_base_ndarray_field(inp_a, first_impl)
field_inp_b = _make_base_ndarray_field(inp_b, second_impl)

result = field_inp_a + field_inp_b
assert np.allclose(result.ndarray, expected)
Expand All @@ -215,9 +243,9 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field:

expected = np.asarray(inp_a) * np.asarray(inp_b) + np.asarray(inp_c)

field_inp_a = _make_field(inp_a, np)
field_inp_b = _make_field(inp_b, np)
field_inp_c = _make_field(inp_c, np)
field_inp_a = _make_base_ndarray_field(inp_a, np)
field_inp_b = _make_base_ndarray_field(inp_b, np)
field_inp_c = _make_base_ndarray_field(inp_c, np)

result = fma(field_inp_a, field_inp_b, field_inp_c)
assert np.allclose(result.ndarray, expected)
Expand Down

0 comments on commit 2a18ccc

Please sign in to comment.