diff --git a/.gitpod.yml b/.gitpod.yml index 802d87796a..c710caae60 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -26,22 +26,3 @@ vscode: - ms-toolsai.jupyter-keymap - ms-toolsai.jupyter-renderers - genuitecllc.codetogether - -github: - prebuilds: - # enable for the master/default branch (defaults to true) - master: true - # enable for all branches in this repo (defaults to false) - branches: false - # enable for pull requests coming from this repo (defaults to true) - pullRequests: true - # enable for pull requests coming from forks (defaults to false) - pullRequestsFromForks: true - # add a check to pull requests (defaults to true) - addCheck: true - # add a "Review in Gitpod" button as a comment to pull requests (defaults to false) - addComment: false - # add a "Review in Gitpod" button to the pull request's description (defaults to false) - addBadge: false - # add a label once the prebuild is ready to pull requests (defaults to false) - addLabel: false diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 3e1fe52f31..87f780b7e2 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -401,6 +401,12 @@ def __and__(self, other: Domain) -> Domain: def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + def is_finite(self) -> bool: + for _, rng in self: + if Infinity.positive() in (abs(rng.start), abs(rng.stop)): + return False + return True + def dim_index(self, dim: Dimension) -> Optional[int]: return self.dims.index(dim) if dim in self.dims else None diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 87e0800a10..b7813c9268 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -11,13 +11,14 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - from __future__ import annotations import functools import itertools import operator +import numpy as np + from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -42,9 +43,7 @@ def _relative_sub_domain( expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): - raise IndexError( - f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." - ) + raise embedded_exceptions.IndexOutOfBounds(domain=domain, indices=index) expanded += (slice(None),) * (len(domain) - len(expanded)) for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): @@ -71,8 +70,12 @@ def _absolute_sub_domain( domain: common.Domain, index: common.AbsoluteIndexSequence ) -> common.Domain: named_ranges: list[common.NamedRange] = [] + + if len(domain) < len(index): + raise embedded_exceptions.IndexOutOfBounds(domain=domain, indices=index) + for i, (dim, rng) in enumerate(domain): - if (pos := _find_index_of_dim(dim, index)) is not None: + if (pos := find_index_of_dim(dim, index)) is not None: named_idx = index[pos] idx = named_idx[1] if isinstance(idx, common.UnitRange): @@ -137,7 +140,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit return common.UnitRange(start, stop) -def _find_index_of_dim( +def find_index_of_dim( dim: common.Dimension, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], ) -> Optional[int]: @@ -145,3 +148,29 @@ def _find_index_of_dim( if dim == d: return i return None + + +def broadcast_domain( + field: common.Field, new_dimensions: tuple[common.Dimension, ...] +) -> Sequence[common.NamedRange]: + named_ranges = [] + for dim in new_dimensions: + if (pos := find_index_of_dim(dim, field.domain)) is not None: + named_ranges.append((dim, field.domain[pos][1])) + else: + named_ranges.append( + (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) + ) + return named_ranges + + +def _compute_domain_slice( + field: common.Field, new_dimensions: tuple[common.Dimension, ...] +) -> Sequence[slice | None]: + domain_slice: list[slice | None] = [] + for dim in new_dimensions: + if find_index_of_dim(dim, field.domain) is not None: + domain_slice.append(slice(None)) + else: + domain_slice.append(np.newaxis) + return domain_slice diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 393123db36..c9d93cb5dc 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Optional + from gt4py.next import common from gt4py.next.errors import exceptions as gt4py_exceptions @@ -19,20 +21,52 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): domain: common.Domain indices: common.AnyIndexSpec - index: common.AnyIndexElement - dim: common.Dimension + index: Optional[common.AnyIndexElement] + dim: Optional[common.Dimension] def __init__( self, domain: common.Domain, indices: common.AnyIndexSpec, - index: common.AnyIndexElement, - dim: common.Dimension, + index: Optional[common.AnyIndexElement] = None, + dim: Optional[common.Dimension] = None, ): - super().__init__( - f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." - ) + msg = f"Out of bounds: slicing {domain} with index `{indices}`." + if index is not None and dim is not None: + msg += f" `{index}` is out of bounds in dimension `{dim}`." + + super().__init__(msg) self.domain = domain self.indices = indices self.index = index self.dim = dim + + +class EmptyDomainIndexError(gt4py_exceptions.GT4PyError): + cls_name: str + + def __init__(self, cls_name: str): + super().__init__(f"Error in `{cls_name}`: Cannot index `{cls_name}` with an empty domain.") + self.cls_name = cls_name + + +class FunctionFieldError(gt4py_exceptions.GT4PyError): + cls_name: str + msg: str + + def __init__(self, cls_name: str, msg: str): + super().__init__(f"Error in `{cls_name}`: {msg}.") + self.cls_name = cls_name + self.msg = msg + + +class InfiniteRangeNdarrayError(gt4py_exceptions.GT4PyError): + cls_name: str + domain: common.Domain + + def __init__(self, cls_name: str, domain: common.Domain): + super().__init__( + f"Error in `{cls_name}`: Cannot construct an ndarray with an infinite range in domain: `{domain}`." + ) + self.cls_name = cls_name + self.domain = domain diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py new file mode 100644 index 0000000000..1c942c5956 --- /dev/null +++ b/src/gt4py/next/embedded/function_field.py @@ -0,0 +1,314 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +import dataclasses +import inspect +import operator +from typing import Any, Callable, TypeGuard, overload + +import numpy as np + +from gt4py._core import definitions as core_defs +from gt4py.next import common +from gt4py.next.embedded import ( + common as embedded_common, + exceptions as embedded_exceptions, + nd_array_field as nd, +) +from gt4py.next.ffront import fbuiltins + + +@dataclasses.dataclass(frozen=True) +class FunctionField(common.FieldBuiltinFuncRegistry, common.Field[common.DimsT, core_defs.ScalarT]): + """A `FunctionField` represents a field of values generated by a callable function over a specified domain. + + The function supplied to the `func` parameter will be used to create the ndarray when accessing + the `ndarray` property. The result of calling `ndarray` will be the same as using + `np.fromfunction` with the provided function. + + Args: + func (Callable): The callable function that generates field values. + domain (common.Domain, optional): The domain over which the function is defined. + Defaults to an empty domain. + + Examples: + Create a FunctionField and compute its ndarray: + + >>> import numpy as np + >>> from gt4py.next import common + >>> from gt4py.next.embedded.function_field import FunctionField + >>> I = common.Dimension("I") + >>> domain = common.Domain((I, common.UnitRange(0, 5))) + >>> func = lambda i: i ** 2 + >>> field = FunctionField(func, domain) + >>> ndarray = field.ndarray + >>> expected_ndarray = np.fromfunction(func, (5,)) + >>> np.array_equal(ndarray, expected_ndarray) + True + """ + + func: Callable + domain: common.Domain = common.Domain() + + def __post_init__(self): + if not callable(self.func): + raise embedded_exceptions.FunctionFieldError( + self.__class__.__name__, + f"Invalid first argument type: Expected a function but got {self.func}", + ) + + if __debug__: + try: + self._trigger_func() + except Exception: + params = _get_params(self.func) + raise embedded_exceptions.FunctionFieldError( + self.__class__.__name__, + f"Invariant violation: len(self.domain) ({len(self.domain)}) does not match the number of parameters of the provided function ({params})", + ) + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return self.domain.dims + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple(-r.start for _, r in self.domain) + + def _trigger_func(self): + # TODO what should happen when domain is empty? test that this is fine. + target_shape = tuple(1 for _ in range(len(self.domain))) + return np.fromfunction(self.func, target_shape) + + @property + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + return core_defs.dtype(self.ndarray.dtype.type) + + def restrict(self, index: common.AnyIndexSpec) -> FunctionField: + new_domain = embedded_common.sub_domain(self.domain, index) + return self.__class__(self.func, new_domain) + + __getitem__ = restrict + + def asnumpy(self) -> core_defs.NDArrayObject: + # handle case where we have a constant FunctionField where field.ndarray is a scalar + if ( + isinstance(self._trigger_func(), (int, float)) and not self.domain.is_finite() + ): # TODO cover all relevant types if this code path still makes sense + return np.full(tuple(1 for _ in self.domain.shape), self.func()) + + if not self.domain.is_finite(): + raise embedded_exceptions.InfiniteRangeNdarrayError( + self.__class__.__name__, self.domain + ) + return np.fromfunction(self.func, self.domain.shape) + + @property + def ndarray(self) -> core_defs.NDArrayObject: + return self.asnumpy() + + def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField: + domain_intersection = self.domain & other.domain + broadcasted_self = _broadcast(self, domain_intersection.dims) + broadcasted_other = _broadcast(other, domain_intersection.dims) + return self.__class__( + _compose(op, broadcasted_self, broadcasted_other), + domain_intersection, + ) + + def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField: + def new_func(*args): + return op(self.func(*args), other) + + return self.__class__( + new_func, self.domain + ) # skip invariant as we cannot deduce number of args + + @overload + def _binary_operation(self, op: Callable, other: core_defs.ScalarT) -> common.Field: + ... + + @overload + def _binary_operation(self, op: Callable, other: common.Field) -> common.Field: + ... + + def _binary_operation(self, op, other): + if isinstance(other, self.__class__): + return self._handle_function_field_op(other, op) + elif isinstance(other, (int, float)): + return self._handle_scalar_op(other, op) + else: + return op(other, self) + + def _unary_op(self, op: Callable) -> FunctionField: + return self.__class__(_compose(op, self), self.domain) + + def __ne__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return NotImplemented # TODO + + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.add, other) + + def __sub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.sub, other) + + def __mul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.mul, other) + + def __truediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.truediv, other) + + def __floordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.floordiv, other) + + def __mod__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.mod, other) + + __rmod__ = __mod__ + + def __pow__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.pow, other) + + def __lt__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.lt, other) + + def __le__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.le, other) + + def __gt__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.gt, other) + + def __ge__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.ge, other) + + def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.and_, other) + + __rand__ = __and__ + + def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.or_, other) + + __ror__ = __or__ + + def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.xor, other) + + __rxor__ = __xor__ + + def __radd__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y + x, other) + + def __rfloordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y // x, other) + + def __rmul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y * x, other) + + def __rsub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y - x, other) + + def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y / x, other) + + def __pos__(self) -> common.Field: + return self._unary_op(operator.pos) + + def __neg__(self) -> common.Field: + return self._unary_op(operator.neg) + + def __abs__(self) -> common.Field: + return self._unary_op(abs) + + def __invert__(self) -> common.Field: + if self.dtype == core_defs.BoolDType(): + return self._unary_op(operator.invert) + raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + + def __call__(self, *args, **kwargs) -> common.Field: + return self.func(*args, **kwargs) + + def remap(self, *args, **kwargs) -> common.Field: + raise NotImplementedError("Method remap not implemented") + + +def _compose(operation: Callable, *fields: FunctionField) -> Callable: + return lambda *args: operation(*[f.func(*args) for f in fields]) + + +def _broadcast(field: FunctionField, dims: tuple[common.Dimension, ...]) -> FunctionField: + def broadcasted_func(*args: int | core_defs.NDArrayObject): + selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims] + return field.func(*selected_args) + + named_ranges = embedded_common.broadcast_domain(field, dims) + return FunctionField(broadcasted_func, common.Domain(*named_ranges)) + + +def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]: + return isinstance(other, nd._BaseNdArrayField) + + +def constant_field( + value: core_defs.ScalarT, domain: common.Domain = common.Domain() +) -> common.Field: + return FunctionField(lambda *args: value, domain) + + +def _compose_function_field_with_builtin(builtin_name: str) -> Callable: + def _composed_function_field(field: FunctionField) -> FunctionField: + if builtin_name not in _UNARY_BUILTINS: + raise ValueError(f"Unsupported built-in function: {builtin_name}") + + if builtin_name in ["abs", "power", "gamma"]: + return field + + builtin_func = getattr(np, builtin_name) + + def new_func(*args): + return builtin_func(field.func(*args)) + + new_field: FunctionField = FunctionField(new_func, field.domain) + return new_field + + return _composed_function_field + + +FunctionField.register_builtin_func(fbuiltins.broadcast, _broadcast) + +_UNARY_BUILTINS = ( + fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES + + fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES + + fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES +) + +for builtin_name in _UNARY_BUILTINS: + if builtin_name in ["abs", "gamma"]: + continue + FunctionField.register_builtin_func( + getattr(fbuiltins, builtin_name), _compose_function_field_with_builtin(builtin_name) + ) + +FunctionField.register_builtin_func(fbuiltins.abs, FunctionField.__abs__) # type: ignore[attr-defined] + + +def _get_params(func: Callable) -> str: + """Pretty print callable parameters.""" + signature = inspect.signature(func) + parameters = signature.parameters + param_strings = [f"{name}: {param}" for name, param in parameters.items()] + formatted_params = ", ".join(param_strings) + return formatted_params diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fbfe64ac42..e253e6cefe 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -18,15 +18,15 @@ import functools from collections.abc import Callable, Sequence from types import ModuleType -from typing import ClassVar import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar +from gt4py.eve.extended_typing import Any, ClassVar, Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common from gt4py.next.embedded import common as embedded_common +from gt4py.next.embedded.common import _compute_domain_slice, broadcast_domain from gt4py.next.ffront import fbuiltins @@ -58,7 +58,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: if f.domain == domain_intersection: transformed.append(xp.asarray(f.ndarray)) else: - f_broadcasted = _broadcast(f, domain_intersection.dims) + f_broadcasted = fbuiltins.broadcast(f, domain_intersection.dims) f_slices = _get_slices_from_domain_slice( f_broadcasted.domain, domain_intersection ) @@ -558,17 +558,8 @@ def __setitem__( def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: - domain_slice: list[slice | None] = [] - named_ranges = [] - for dim in new_dimensions: - if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: - domain_slice.append(slice(None)) - named_ranges.append((dim, field.domain[pos][1])) - else: - domain_slice.append(np.newaxis) - named_ranges.append( - (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) - ) + domain_slice = _compute_domain_slice(field, new_dimensions) + named_ranges = broadcast_domain(field, new_dimensions) return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -613,7 +604,7 @@ def _get_slices_from_domain_slice( slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): - if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: + if (pos := embedded_common.find_index_of_dim(dim, domain_slice)) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: diff --git a/tests/next_tests/unit_tests/embedded_tests/__init__.py b/tests/next_tests/unit_tests/embedded_tests/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index de511fdabb..7ea207acbd 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -124,6 +124,11 @@ def test_slice_range(rng, slce, expected): (slice(1, 2), slice(1, 2), Ellipsis), [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], ), + ([], Ellipsis, []), + ([], slice(None), IndexError), + ([], 0, IndexError), + ([], (I, 0), IndexError), + # ([], (), []), # once we implement the array API standard ], ) def test_sub_domain(domain, index, expected): @@ -137,6 +142,33 @@ def test_sub_domain(domain, index, expected): assert result == expected +@pytest.fixture +def get_finite_domain(): + return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) + + +@pytest.fixture +def get_infinite_domain(): + return common.Domain((I, UnitRange.infinity()), (J, UnitRange.infinity())) + + +@pytest.fixture +def get_mixed_domain(): + return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity())) + + +def test_finite_domain_is_finite(get_finite_domain): + assert get_finite_domain.is_finite() + + +def test_infinite_domain_is_finite(get_infinite_domain): + assert not get_infinite_domain.is_finite() + + +def test_mixed_domain_is_finite(get_mixed_domain): + assert not get_mixed_domain.is_finite() + + def test_iterate_domain(): domain = common.domain({I: 2, J: 3}) ref = [] diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py new file mode 100644 index 0000000000..04ef220746 --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -0,0 +1,257 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +import operator + +import numpy as np +import pytest + +from gt4py.next import common +from gt4py.next.common import Dimension, UnitRange +from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf + +from .test_common import get_infinite_domain, get_mixed_domain +from .test_nd_array_field import ( + binary_arithmetic_op, + binary_logical_op, + binary_reverse_arithmetic_op, +) + + +I = Dimension("I") +J = Dimension("J") +K = Dimension("K") + + +def test_constant_field_no_domain(binary_arithmetic_op, binary_reverse_arithmetic_op): + cf1 = funcf.constant_field(10) + cf2 = funcf.constant_field(20) + + ops = [binary_arithmetic_op, binary_reverse_arithmetic_op] + + for op in ops: + result = op(cf1, cf2) + assert result.func() == op(10, 20) + + +@pytest.fixture( + params=[((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))] +) +def test_index(request): + return request.param + + +def test_constant_field_getitem_missing_domain(test_index): + cf = funcf.constant_field(10) + with pytest.raises(embedded_exceptions.IndexOutOfBounds): + cf[test_index] + + +def test_constant_field_getitem_missing_domain_ellipsis(test_index): + cf = funcf.constant_field(10) + cf[...].domain == cf.domain + + +@pytest.mark.parametrize( + "domain", + [ + common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), + common.Domain( + dims=(I, J, K), ranges=(UnitRange(-6, -3), UnitRange(-5, 10), UnitRange(1, 2)) + ), + ], +) +def test_constant_field_ndarray(domain): + cf = funcf.constant_field(10, domain) + assert isinstance(cf.asnumpy(), int) + assert cf.asnumpy() == 10 + + +def test_constant_field_and_field_op(): + domain = common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) + field = common.field(np.ones((10, 10)), domain=domain) + cf = funcf.constant_field(10) + + result = cf + field + assert np.allclose(result.asnumpy(), 11) + assert result.domain == domain + + +binary_op_field_intersection_cases = [ + ( + common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), + np.ones((10, 10)), + common.Domain(dims=(I, J), ranges=(UnitRange(3, 5), UnitRange(0, 5))), + 2.0, + (2, 5), + 3, + ), + ( + common.Domain(dims=(I, J), ranges=(UnitRange(-5, 2), UnitRange(3, 8))), + np.ones((7, 5)), + common.Domain(dims=(I,), ranges=(UnitRange(-5, 0),)), + 5, + (5, 5), + 6, + ), +] + + +def adder(i, j): + return i + j + + +def test_function_field_broadcast(binary_arithmetic_op, binary_reverse_arithmetic_op): + func1 = lambda x, y: x + y + func2 = lambda y: 2 * y + + domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))) + domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) + + ff1 = funcf.FunctionField(func1, domain1) + ff2 = funcf.FunctionField(func2, domain2) + + ops = [binary_arithmetic_op, binary_reverse_arithmetic_op] + + for op in ops: + result = op(ff1, ff2) + + assert result.func(5, 10) == op(func1(5, 10), func2(10)) + assert isinstance(result.ndarray, np.ndarray) + + +def test_function_field_logical_operators(binary_logical_op): + func1 = lambda x, y: x > 5 + func2 = lambda y: y < 10 + + domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))) + domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) + + ff1 = funcf.FunctionField(func1, domain1) + ff2 = funcf.FunctionField(func2, domain2) + + result = binary_logical_op(ff1, ff2) + + assert result.func(5, 10) == binary_logical_op(func1(5, 10), func2(10)) + assert isinstance(result.ndarray, np.ndarray) + + +@pytest.mark.parametrize( + "domain,expected_shape", + [ + (common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), (10, 10)), + ], +) +def test_function_field_ndarray(domain, expected_shape): + ff = funcf.FunctionField(adder, domain) + assert ff.ndarray.shape == expected_shape + + ff_func = lambda *indices: adder(*indices) + expected_values = np.fromfunction(ff_func, ff.ndarray.shape) + assert np.allclose(ff.ndarray, expected_values) + + +@pytest.mark.parametrize( + "domain", + [ + common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), + ], +) +def test_function_field_with_field(domain): + ff = funcf.FunctionField(adder, domain) + field = common.field(np.ones((10, 10)), domain=domain) + + result = ff + field + ff_func = lambda *indices: adder(*indices) + 1 + expected_values = np.fromfunction(ff_func, result.ndarray.shape) + + assert result.ndarray.shape == (10, 10) + assert np.allclose(result.ndarray, expected_values) + + +def test_function_field_function_field_op(): + res = funcf.FunctionField( + lambda x, y: x + 42 * y, + domain=common.Domain( + dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + ), + ) + funcf.FunctionField( + lambda y: 2 * y, domain=common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) + ) + + assert res.func(1, 2) == 89 + + +@pytest.fixture +def function_field(): + return funcf.FunctionField( + adder, + domain=common.Domain( + dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + ), + ) + + +def test_function_field_unary(function_field): + pos_result = +function_field + assert pos_result.func(1, 2) == 3 + + neg_result = -function_field + assert neg_result.func(1, 2) == -3 + + abs_result = abs(function_field) + assert abs_result.func(1, 2) == 3 + + +def test_function_field_scalar_op(function_field): + new = function_field * 5.0 + assert new.func(1, 2) == 15 + + +@pytest.mark.parametrize("func", ["foo", 1.0, 1]) +def test_function_field_invalid_func(func): + with pytest.raises(embedded_exceptions.FunctionFieldError, match="Invalid first argument type"): + funcf.FunctionField(func) + + +@pytest.mark.parametrize( + "domain", + [ + common.Domain(), + common.Domain(*((I, UnitRange(1, 10)), (J, UnitRange(5, 10)))), + ], +) +def test_function_field_invalid_invariant(domain): + with pytest.raises(embedded_exceptions.FunctionFieldError, match="Invariant violation"): + funcf.FunctionField(lambda *args, x: x, domain) + + +def test_function_field_infinite_range(get_infinite_domain, get_mixed_domain): + domains = [get_infinite_domain, get_mixed_domain] + for d in domains: + with pytest.raises(embedded_exceptions.InfiniteRangeNdarrayError): + ff = funcf.FunctionField(adder, d) + ff.ndarray + + +def test_unary_logical_op_boolean(): + boolean_func = lambda x: x % 2 != 0 + field = funcf.FunctionField(boolean_func, common.Domain((I, UnitRange(1, 10)))) + assert np.allclose(~field.ndarray, np.invert(np.fromfunction(boolean_func, (9,)))) + + +def test_unary_logical_op_scalar(): + scalar_func = lambda x: x % 2 + field = funcf.FunctionField(scalar_func, common.Domain((I, UnitRange(1, 10)))) + with pytest.raises(NotImplementedError): + ~field diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1a38e5245e..67341281eb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,23 +11,22 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import dataclasses import itertools -import math import operator -from typing import Callable, Iterable import numpy as np import pytest -from gt4py.next import common, embedded +from gt4py.next import Dimension, common, embedded from gt4py.next.common import Dimension, Domain, UnitRange -from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field +from gt4py.next.embedded import ( + exceptions as embedded_exceptions, + function_field as funcf, + nd_array_field, +) from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins -from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data - IDim = Dimension("IDim") JDim = Dimension("JDim") @@ -70,10 +69,23 @@ def unary_logical_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=None): +@pytest.fixture( + params=[ + lambda x, y: operator.truediv(y, x), # Reverse true division + lambda x, y: operator.add(y, x), # Reverse addition + lambda x, y: operator.mul(y, x), # Reverse multiplication + lambda x, y: operator.sub(y, x), # Reverse subtraction + lambda x, y: operator.floordiv(y, x), # Reverse floor division + ] +) +def binary_reverse_arithmetic_op(request): + yield request.param + + +def _make_base_ndarray_field(arr: np.ndarray, nd_array_implementation, *, domain=None, dtype=None): if not dtype: dtype = nd_array_implementation.float32 - buffer = nd_array_implementation.asarray(lst, dtype=dtype) + buffer = nd_array_implementation.asarray(arr, dtype=dtype) if domain is None: domain = tuple( (common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape) @@ -84,23 +96,37 @@ def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=No ) -@pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) -def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementation): +def _make_function_field(): + def adder(i, j): + return i + j + + return funcf.FunctionField( + adder, + domain=common.Domain( + dims=(IDim, JDim), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + ), + ) + + +normal_dist = np.random.normal(3, 2.5, size=(10,)) + + +@pytest.fixture(params=[_make_base_ndarray_field(normal_dist, np), _make_function_field()]) +def all_field_types(request): + yield request.param + + +@pytest.mark.parametrize("builtin_name", funcf._UNARY_BUILTINS) +def test_unary_builtins_for_all_fields(all_field_types, builtin_name): if builtin_name == "gamma": # numpy has no gamma function pytest.xfail("TODO: implement gamma") - ref_impl: Callable = np.vectorize(math.gamma) - else: - ref_impl: Callable = getattr(np, builtin_name) - - expected = ref_impl(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) - - field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] - builtin = getattr(fbuiltins, builtin_name) - result = builtin(*field_inputs) + fbuiltin_func = getattr(fbuiltins, builtin_name) + result = fbuiltin_func(all_field_types).asnumpy() - assert np.allclose(result.ndarray, expected) + expected = getattr(np, builtin_name)(all_field_types.asnumpy()) + assert np.allclose(result, expected, equal_nan=True) def test_where_builtin(nd_array_implementation): @@ -108,7 +134,9 @@ def test_where_builtin(nd_array_implementation): true_ = np.asarray([1.0, 2.0], dtype=np.float32) false_ = np.asarray([3.0, 4.0], dtype=np.float32) - field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]] + field_inputs = [ + _make_base_ndarray_field(inp, nd_array_implementation) for inp in [cond, true_, false_] + ] expected = np.where(cond, true_, false_) result = fbuiltins.where(*field_inputs) @@ -148,27 +176,36 @@ def test_where_builtin_with_tuple(nd_array_implementation): expected0 = np.where(cond, true0, false0) expected1 = np.where(cond, true1, false1) - cond_field = _make_field(cond, nd_array_implementation, dtype=bool) - field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1]) - field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1]) + cond_field = _make_base_ndarray_field(cond, nd_array_implementation, dtype=bool) + field_true = tuple( + _make_base_ndarray_field(inp, nd_array_implementation) for inp in [true0, true1] + ) + field_false = tuple( + _make_base_ndarray_field(inp, nd_array_implementation) for inp in [false0, false1] + ) result = fbuiltins.where(cond_field, field_true, field_false) assert np.allclose(result[0].ndarray, expected0) assert np.allclose(result[1].ndarray, expected1) -def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): +def test_binary_arithmetic_ops( + binary_arithmetic_op, binary_reverse_arithmetic_op, nd_array_implementation +): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] inputs = [inp_a, inp_b] - expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) + ops = [binary_arithmetic_op, binary_reverse_arithmetic_op] - field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] + for op in ops: + expected = op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) - result = binary_arithmetic_op(*field_inputs) + field_inputs = [_make_base_ndarray_field(inp, nd_array_implementation) for inp in inputs] - assert np.allclose(result.ndarray, expected) + result = op(*field_inputs) + + assert np.allclose(result.ndarray, expected) def test_binary_logical_ops(binary_logical_op, nd_array_implementation): @@ -178,7 +215,9 @@ def test_binary_logical_ops(binary_logical_op, nd_array_implementation): expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) - field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + field_inputs = [ + _make_base_ndarray_field(inp, nd_array_implementation, dtype=bool) for inp in inputs + ] result = binary_logical_op(*field_inputs) @@ -193,7 +232,7 @@ def test_unary_logical_ops(unary_logical_op, nd_array_implementation): expected = unary_logical_op(np.asarray(inp)) - field_input = _make_field(inp, nd_array_implementation, dtype=bool) + field_input = _make_base_ndarray_field(inp, nd_array_implementation, dtype=bool) result = unary_logical_op(field_input) @@ -205,7 +244,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) - field_input = _make_field(inp, nd_array_implementation) + field_input = _make_base_ndarray_field(inp, nd_array_implementation) result = unary_arithmetic_op(field_input) @@ -256,8 +295,8 @@ def test_mixed_fields(product_nd_array_implementation): expected = np.asarray(inp_a) + np.asarray(inp_b) - field_inp_a = _make_field(inp_a, first_impl) - field_inp_b = _make_field(inp_b, second_impl) + field_inp_a = _make_base_ndarray_field(inp_a, first_impl) + field_inp_b = _make_base_ndarray_field(inp_b, second_impl) result = field_inp_a + field_inp_b assert np.allclose(result.ndarray, expected) @@ -274,9 +313,9 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: expected = np.asarray(inp_a) * np.asarray(inp_b) + np.asarray(inp_c) - field_inp_a = _make_field(inp_a, np) - field_inp_b = _make_field(inp_b, np) - field_inp_c = _make_field(inp_c, np) + field_inp_a = _make_base_ndarray_field(inp_a, np) + field_inp_b = _make_base_ndarray_field(inp_b, np) + field_inp_c = _make_base_ndarray_field(inp_c, np) result = fma(field_inp_a, field_inp_b, field_inp_c) assert np.allclose(result.ndarray, expected)