diff --git a/.gitpod.yml b/.gitpod.yml
index 802d87796a..c710caae60 100644
--- a/.gitpod.yml
+++ b/.gitpod.yml
@@ -26,22 +26,3 @@ vscode:
- ms-toolsai.jupyter-keymap
- ms-toolsai.jupyter-renderers
- genuitecllc.codetogether
-
-github:
- prebuilds:
- # enable for the master/default branch (defaults to true)
- master: true
- # enable for all branches in this repo (defaults to false)
- branches: false
- # enable for pull requests coming from this repo (defaults to true)
- pullRequests: true
- # enable for pull requests coming from forks (defaults to false)
- pullRequestsFromForks: true
- # add a check to pull requests (defaults to true)
- addCheck: true
- # add a "Review in Gitpod" button as a comment to pull requests (defaults to false)
- addComment: false
- # add a "Review in Gitpod" button to the pull request's description (defaults to false)
- addBadge: false
- # add a label once the prebuild is ready to pull requests (defaults to false)
- addLabel: false
diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py
index 3e1fe52f31..87f780b7e2 100644
--- a/src/gt4py/next/common.py
+++ b/src/gt4py/next/common.py
@@ -401,6 +401,12 @@ def __and__(self, other: Domain) -> Domain:
def __str__(self) -> str:
return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})"
+ def is_finite(self) -> bool:
+ for _, rng in self:
+ if Infinity.positive() in (abs(rng.start), abs(rng.stop)):
+ return False
+ return True
+
def dim_index(self, dim: Dimension) -> Optional[int]:
return self.dims.index(dim) if dim in self.dims else None
diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py
index 87e0800a10..b7813c9268 100644
--- a/src/gt4py/next/embedded/common.py
+++ b/src/gt4py/next/embedded/common.py
@@ -11,13 +11,14 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
-
from __future__ import annotations
import functools
import itertools
import operator
+import numpy as np
+
from gt4py.eve.extended_typing import Any, Optional, Sequence, cast
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions
@@ -42,9 +43,7 @@ def _relative_sub_domain(
expanded = _expand_ellipsis(index, len(domain))
if len(domain) < len(expanded):
- raise IndexError(
- f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions."
- )
+ raise embedded_exceptions.IndexOutOfBounds(domain=domain, indices=index)
expanded += (slice(None),) * (len(domain) - len(expanded))
for (dim, rng), idx in zip(domain, expanded, strict=True):
if isinstance(idx, slice):
@@ -71,8 +70,12 @@ def _absolute_sub_domain(
domain: common.Domain, index: common.AbsoluteIndexSequence
) -> common.Domain:
named_ranges: list[common.NamedRange] = []
+
+ if len(domain) < len(index):
+ raise embedded_exceptions.IndexOutOfBounds(domain=domain, indices=index)
+
for i, (dim, rng) in enumerate(domain):
- if (pos := _find_index_of_dim(dim, index)) is not None:
+ if (pos := find_index_of_dim(dim, index)) is not None:
named_idx = index[pos]
idx = named_idx[1]
if isinstance(idx, common.UnitRange):
@@ -137,7 +140,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit
return common.UnitRange(start, stop)
-def _find_index_of_dim(
+def find_index_of_dim(
dim: common.Dimension,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
) -> Optional[int]:
@@ -145,3 +148,29 @@ def _find_index_of_dim(
if dim == d:
return i
return None
+
+
+def broadcast_domain(
+ field: common.Field, new_dimensions: tuple[common.Dimension, ...]
+) -> Sequence[common.NamedRange]:
+ named_ranges = []
+ for dim in new_dimensions:
+ if (pos := find_index_of_dim(dim, field.domain)) is not None:
+ named_ranges.append((dim, field.domain[pos][1]))
+ else:
+ named_ranges.append(
+ (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive()))
+ )
+ return named_ranges
+
+
+def _compute_domain_slice(
+ field: common.Field, new_dimensions: tuple[common.Dimension, ...]
+) -> Sequence[slice | None]:
+ domain_slice: list[slice | None] = []
+ for dim in new_dimensions:
+ if find_index_of_dim(dim, field.domain) is not None:
+ domain_slice.append(slice(None))
+ else:
+ domain_slice.append(np.newaxis)
+ return domain_slice
diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py
index 393123db36..c9d93cb5dc 100644
--- a/src/gt4py/next/embedded/exceptions.py
+++ b/src/gt4py/next/embedded/exceptions.py
@@ -12,6 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
+from typing import Optional
+
from gt4py.next import common
from gt4py.next.errors import exceptions as gt4py_exceptions
@@ -19,20 +21,52 @@
class IndexOutOfBounds(gt4py_exceptions.GT4PyError):
domain: common.Domain
indices: common.AnyIndexSpec
- index: common.AnyIndexElement
- dim: common.Dimension
+ index: Optional[common.AnyIndexElement]
+ dim: Optional[common.Dimension]
def __init__(
self,
domain: common.Domain,
indices: common.AnyIndexSpec,
- index: common.AnyIndexElement,
- dim: common.Dimension,
+ index: Optional[common.AnyIndexElement] = None,
+ dim: Optional[common.Dimension] = None,
):
- super().__init__(
- f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`."
- )
+ msg = f"Out of bounds: slicing {domain} with index `{indices}`."
+ if index is not None and dim is not None:
+ msg += f" `{index}` is out of bounds in dimension `{dim}`."
+
+ super().__init__(msg)
self.domain = domain
self.indices = indices
self.index = index
self.dim = dim
+
+
+class EmptyDomainIndexError(gt4py_exceptions.GT4PyError):
+ cls_name: str
+
+ def __init__(self, cls_name: str):
+ super().__init__(f"Error in `{cls_name}`: Cannot index `{cls_name}` with an empty domain.")
+ self.cls_name = cls_name
+
+
+class FunctionFieldError(gt4py_exceptions.GT4PyError):
+ cls_name: str
+ msg: str
+
+ def __init__(self, cls_name: str, msg: str):
+ super().__init__(f"Error in `{cls_name}`: {msg}.")
+ self.cls_name = cls_name
+ self.msg = msg
+
+
+class InfiniteRangeNdarrayError(gt4py_exceptions.GT4PyError):
+ cls_name: str
+ domain: common.Domain
+
+ def __init__(self, cls_name: str, domain: common.Domain):
+ super().__init__(
+ f"Error in `{cls_name}`: Cannot construct an ndarray with an infinite range in domain: `{domain}`."
+ )
+ self.cls_name = cls_name
+ self.domain = domain
diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py
new file mode 100644
index 0000000000..1c942c5956
--- /dev/null
+++ b/src/gt4py/next/embedded/function_field.py
@@ -0,0 +1,314 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
+import dataclasses
+import inspect
+import operator
+from typing import Any, Callable, TypeGuard, overload
+
+import numpy as np
+
+from gt4py._core import definitions as core_defs
+from gt4py.next import common
+from gt4py.next.embedded import (
+ common as embedded_common,
+ exceptions as embedded_exceptions,
+ nd_array_field as nd,
+)
+from gt4py.next.ffront import fbuiltins
+
+
+@dataclasses.dataclass(frozen=True)
+class FunctionField(common.FieldBuiltinFuncRegistry, common.Field[common.DimsT, core_defs.ScalarT]):
+ """A `FunctionField` represents a field of values generated by a callable function over a specified domain.
+
+ The function supplied to the `func` parameter will be used to create the ndarray when accessing
+ the `ndarray` property. The result of calling `ndarray` will be the same as using
+ `np.fromfunction` with the provided function.
+
+ Args:
+ func (Callable): The callable function that generates field values.
+ domain (common.Domain, optional): The domain over which the function is defined.
+ Defaults to an empty domain.
+
+ Examples:
+ Create a FunctionField and compute its ndarray:
+
+ >>> import numpy as np
+ >>> from gt4py.next import common
+ >>> from gt4py.next.embedded.function_field import FunctionField
+ >>> I = common.Dimension("I")
+ >>> domain = common.Domain((I, common.UnitRange(0, 5)))
+ >>> func = lambda i: i ** 2
+ >>> field = FunctionField(func, domain)
+ >>> ndarray = field.ndarray
+ >>> expected_ndarray = np.fromfunction(func, (5,))
+ >>> np.array_equal(ndarray, expected_ndarray)
+ True
+ """
+
+ func: Callable
+ domain: common.Domain = common.Domain()
+
+ def __post_init__(self):
+ if not callable(self.func):
+ raise embedded_exceptions.FunctionFieldError(
+ self.__class__.__name__,
+ f"Invalid first argument type: Expected a function but got {self.func}",
+ )
+
+ if __debug__:
+ try:
+ self._trigger_func()
+ except Exception:
+ params = _get_params(self.func)
+ raise embedded_exceptions.FunctionFieldError(
+ self.__class__.__name__,
+ f"Invariant violation: len(self.domain) ({len(self.domain)}) does not match the number of parameters of the provided function ({params})",
+ )
+
+ @property
+ def __gt_dims__(self) -> tuple[common.Dimension, ...]:
+ return self.domain.dims
+
+ @property
+ def __gt_origin__(self) -> tuple[int, ...]:
+ return tuple(-r.start for _, r in self.domain)
+
+ def _trigger_func(self):
+ # TODO what should happen when domain is empty? test that this is fine.
+ target_shape = tuple(1 for _ in range(len(self.domain)))
+ return np.fromfunction(self.func, target_shape)
+
+ @property
+ def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
+ return core_defs.dtype(self.ndarray.dtype.type)
+
+ def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
+ new_domain = embedded_common.sub_domain(self.domain, index)
+ return self.__class__(self.func, new_domain)
+
+ __getitem__ = restrict
+
+ def asnumpy(self) -> core_defs.NDArrayObject:
+ # handle case where we have a constant FunctionField where field.ndarray is a scalar
+ if (
+ isinstance(self._trigger_func(), (int, float)) and not self.domain.is_finite()
+ ): # TODO cover all relevant types if this code path still makes sense
+ return np.full(tuple(1 for _ in self.domain.shape), self.func())
+
+ if not self.domain.is_finite():
+ raise embedded_exceptions.InfiniteRangeNdarrayError(
+ self.__class__.__name__, self.domain
+ )
+ return np.fromfunction(self.func, self.domain.shape)
+
+ @property
+ def ndarray(self) -> core_defs.NDArrayObject:
+ return self.asnumpy()
+
+ def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField:
+ domain_intersection = self.domain & other.domain
+ broadcasted_self = _broadcast(self, domain_intersection.dims)
+ broadcasted_other = _broadcast(other, domain_intersection.dims)
+ return self.__class__(
+ _compose(op, broadcasted_self, broadcasted_other),
+ domain_intersection,
+ )
+
+ def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField:
+ def new_func(*args):
+ return op(self.func(*args), other)
+
+ return self.__class__(
+ new_func, self.domain
+ ) # skip invariant as we cannot deduce number of args
+
+ @overload
+ def _binary_operation(self, op: Callable, other: core_defs.ScalarT) -> common.Field:
+ ...
+
+ @overload
+ def _binary_operation(self, op: Callable, other: common.Field) -> common.Field:
+ ...
+
+ def _binary_operation(self, op, other):
+ if isinstance(other, self.__class__):
+ return self._handle_function_field_op(other, op)
+ elif isinstance(other, (int, float)):
+ return self._handle_scalar_op(other, op)
+ else:
+ return op(other, self)
+
+ def _unary_op(self, op: Callable) -> FunctionField:
+ return self.__class__(_compose(op, self), self.domain)
+
+ def __ne__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return NotImplemented # TODO
+
+ def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.add, other)
+
+ def __sub__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.sub, other)
+
+ def __mul__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.mul, other)
+
+ def __truediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.truediv, other)
+
+ def __floordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.floordiv, other)
+
+ def __mod__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.mod, other)
+
+ __rmod__ = __mod__
+
+ def __pow__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.pow, other)
+
+ def __lt__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.lt, other)
+
+ def __le__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.le, other)
+
+ def __gt__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.gt, other)
+
+ def __ge__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.ge, other)
+
+ def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.and_, other)
+
+ __rand__ = __and__
+
+ def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.or_, other)
+
+ __ror__ = __or__
+
+ def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(operator.xor, other)
+
+ __rxor__ = __xor__
+
+ def __radd__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(lambda x, y: y + x, other)
+
+ def __rfloordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(lambda x, y: y // x, other)
+
+ def __rmul__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(lambda x, y: y * x, other)
+
+ def __rsub__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(lambda x, y: y - x, other)
+
+ def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
+ return self._binary_operation(lambda x, y: y / x, other)
+
+ def __pos__(self) -> common.Field:
+ return self._unary_op(operator.pos)
+
+ def __neg__(self) -> common.Field:
+ return self._unary_op(operator.neg)
+
+ def __abs__(self) -> common.Field:
+ return self._unary_op(abs)
+
+ def __invert__(self) -> common.Field:
+ if self.dtype == core_defs.BoolDType():
+ return self._unary_op(operator.invert)
+ raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.")
+
+ def __call__(self, *args, **kwargs) -> common.Field:
+ return self.func(*args, **kwargs)
+
+ def remap(self, *args, **kwargs) -> common.Field:
+ raise NotImplementedError("Method remap not implemented")
+
+
+def _compose(operation: Callable, *fields: FunctionField) -> Callable:
+ return lambda *args: operation(*[f.func(*args) for f in fields])
+
+
+def _broadcast(field: FunctionField, dims: tuple[common.Dimension, ...]) -> FunctionField:
+ def broadcasted_func(*args: int | core_defs.NDArrayObject):
+ selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims]
+ return field.func(*selected_args)
+
+ named_ranges = embedded_common.broadcast_domain(field, dims)
+ return FunctionField(broadcasted_func, common.Domain(*named_ranges))
+
+
+def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]:
+ return isinstance(other, nd._BaseNdArrayField)
+
+
+def constant_field(
+ value: core_defs.ScalarT, domain: common.Domain = common.Domain()
+) -> common.Field:
+ return FunctionField(lambda *args: value, domain)
+
+
+def _compose_function_field_with_builtin(builtin_name: str) -> Callable:
+ def _composed_function_field(field: FunctionField) -> FunctionField:
+ if builtin_name not in _UNARY_BUILTINS:
+ raise ValueError(f"Unsupported built-in function: {builtin_name}")
+
+ if builtin_name in ["abs", "power", "gamma"]:
+ return field
+
+ builtin_func = getattr(np, builtin_name)
+
+ def new_func(*args):
+ return builtin_func(field.func(*args))
+
+ new_field: FunctionField = FunctionField(new_func, field.domain)
+ return new_field
+
+ return _composed_function_field
+
+
+FunctionField.register_builtin_func(fbuiltins.broadcast, _broadcast)
+
+_UNARY_BUILTINS = (
+ fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES
+ + fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES
+ + fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES
+)
+
+for builtin_name in _UNARY_BUILTINS:
+ if builtin_name in ["abs", "gamma"]:
+ continue
+ FunctionField.register_builtin_func(
+ getattr(fbuiltins, builtin_name), _compose_function_field_with_builtin(builtin_name)
+ )
+
+FunctionField.register_builtin_func(fbuiltins.abs, FunctionField.__abs__) # type: ignore[attr-defined]
+
+
+def _get_params(func: Callable) -> str:
+ """Pretty print callable parameters."""
+ signature = inspect.signature(func)
+ parameters = signature.parameters
+ param_strings = [f"{name}: {param}" for name, param in parameters.items()]
+ formatted_params = ", ".join(param_strings)
+ return formatted_params
diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py
index fbfe64ac42..e253e6cefe 100644
--- a/src/gt4py/next/embedded/nd_array_field.py
+++ b/src/gt4py/next/embedded/nd_array_field.py
@@ -18,15 +18,15 @@
import functools
from collections.abc import Callable, Sequence
from types import ModuleType
-from typing import ClassVar
import numpy as np
from numpy import typing as npt
from gt4py._core import definitions as core_defs
-from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
+from gt4py.eve.extended_typing import Any, ClassVar, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
+from gt4py.next.embedded.common import _compute_domain_slice, broadcast_domain
from gt4py.next.ffront import fbuiltins
@@ -58,7 +58,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
if f.domain == domain_intersection:
transformed.append(xp.asarray(f.ndarray))
else:
- f_broadcasted = _broadcast(f, domain_intersection.dims)
+ f_broadcasted = fbuiltins.broadcast(f, domain_intersection.dims)
f_slices = _get_slices_from_domain_slice(
f_broadcasted.domain, domain_intersection
)
@@ -558,17 +558,8 @@ def __setitem__(
def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
- domain_slice: list[slice | None] = []
- named_ranges = []
- for dim in new_dimensions:
- if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None:
- domain_slice.append(slice(None))
- named_ranges.append((dim, field.domain[pos][1]))
- else:
- domain_slice.append(np.newaxis)
- named_ranges.append(
- (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive()))
- )
+ domain_slice = _compute_domain_slice(field, new_dimensions)
+ named_ranges = broadcast_domain(field, new_dimensions)
return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
@@ -613,7 +604,7 @@ def _get_slices_from_domain_slice(
slice_indices: list[slice | common.IntIndex] = []
for pos_old, (dim, _) in enumerate(domain):
- if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None:
+ if (pos := embedded_common.find_index_of_dim(dim, domain_slice)) is not None:
index_or_range = domain_slice[pos][1]
slice_indices.append(_compute_slice(index_or_range, domain, pos_old))
else:
diff --git a/tests/next_tests/unit_tests/embedded_tests/__init__.py b/tests/next_tests/unit_tests/embedded_tests/__init__.py
new file mode 100644
index 0000000000..6c43e2f12a
--- /dev/null
+++ b/tests/next_tests/unit_tests/embedded_tests/__init__.py
@@ -0,0 +1,13 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py
index de511fdabb..7ea207acbd 100644
--- a/tests/next_tests/unit_tests/embedded_tests/test_common.py
+++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py
@@ -124,6 +124,11 @@ def test_slice_range(rng, slce, expected):
(slice(1, 2), slice(1, 2), Ellipsis),
[(I, (3, 4)), (J, (4, 5)), (K, (4, 7))],
),
+ ([], Ellipsis, []),
+ ([], slice(None), IndexError),
+ ([], 0, IndexError),
+ ([], (I, 0), IndexError),
+ # ([], (), []), # once we implement the array API standard
],
)
def test_sub_domain(domain, index, expected):
@@ -137,6 +142,33 @@ def test_sub_domain(domain, index, expected):
assert result == expected
+@pytest.fixture
+def get_finite_domain():
+ return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4)))
+
+
+@pytest.fixture
+def get_infinite_domain():
+ return common.Domain((I, UnitRange.infinity()), (J, UnitRange.infinity()))
+
+
+@pytest.fixture
+def get_mixed_domain():
+ return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity()))
+
+
+def test_finite_domain_is_finite(get_finite_domain):
+ assert get_finite_domain.is_finite()
+
+
+def test_infinite_domain_is_finite(get_infinite_domain):
+ assert not get_infinite_domain.is_finite()
+
+
+def test_mixed_domain_is_finite(get_mixed_domain):
+ assert not get_mixed_domain.is_finite()
+
+
def test_iterate_domain():
domain = common.domain({I: 2, J: 3})
ref = []
diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py
new file mode 100644
index 0000000000..04ef220746
--- /dev/null
+++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py
@@ -0,0 +1,257 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+import operator
+
+import numpy as np
+import pytest
+
+from gt4py.next import common
+from gt4py.next.common import Dimension, UnitRange
+from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf
+
+from .test_common import get_infinite_domain, get_mixed_domain
+from .test_nd_array_field import (
+ binary_arithmetic_op,
+ binary_logical_op,
+ binary_reverse_arithmetic_op,
+)
+
+
+I = Dimension("I")
+J = Dimension("J")
+K = Dimension("K")
+
+
+def test_constant_field_no_domain(binary_arithmetic_op, binary_reverse_arithmetic_op):
+ cf1 = funcf.constant_field(10)
+ cf2 = funcf.constant_field(20)
+
+ ops = [binary_arithmetic_op, binary_reverse_arithmetic_op]
+
+ for op in ops:
+ result = op(cf1, cf2)
+ assert result.func() == op(10, 20)
+
+
+@pytest.fixture(
+ params=[((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))]
+)
+def test_index(request):
+ return request.param
+
+
+def test_constant_field_getitem_missing_domain(test_index):
+ cf = funcf.constant_field(10)
+ with pytest.raises(embedded_exceptions.IndexOutOfBounds):
+ cf[test_index]
+
+
+def test_constant_field_getitem_missing_domain_ellipsis(test_index):
+ cf = funcf.constant_field(10)
+ cf[...].domain == cf.domain
+
+
+@pytest.mark.parametrize(
+ "domain",
+ [
+ common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))),
+ common.Domain(
+ dims=(I, J, K), ranges=(UnitRange(-6, -3), UnitRange(-5, 10), UnitRange(1, 2))
+ ),
+ ],
+)
+def test_constant_field_ndarray(domain):
+ cf = funcf.constant_field(10, domain)
+ assert isinstance(cf.asnumpy(), int)
+ assert cf.asnumpy() == 10
+
+
+def test_constant_field_and_field_op():
+ domain = common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5)))
+ field = common.field(np.ones((10, 10)), domain=domain)
+ cf = funcf.constant_field(10)
+
+ result = cf + field
+ assert np.allclose(result.asnumpy(), 11)
+ assert result.domain == domain
+
+
+binary_op_field_intersection_cases = [
+ (
+ common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))),
+ np.ones((10, 10)),
+ common.Domain(dims=(I, J), ranges=(UnitRange(3, 5), UnitRange(0, 5))),
+ 2.0,
+ (2, 5),
+ 3,
+ ),
+ (
+ common.Domain(dims=(I, J), ranges=(UnitRange(-5, 2), UnitRange(3, 8))),
+ np.ones((7, 5)),
+ common.Domain(dims=(I,), ranges=(UnitRange(-5, 0),)),
+ 5,
+ (5, 5),
+ 6,
+ ),
+]
+
+
+def adder(i, j):
+ return i + j
+
+
+def test_function_field_broadcast(binary_arithmetic_op, binary_reverse_arithmetic_op):
+ func1 = lambda x, y: x + y
+ func2 = lambda y: 2 * y
+
+ domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)))
+ domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),))
+
+ ff1 = funcf.FunctionField(func1, domain1)
+ ff2 = funcf.FunctionField(func2, domain2)
+
+ ops = [binary_arithmetic_op, binary_reverse_arithmetic_op]
+
+ for op in ops:
+ result = op(ff1, ff2)
+
+ assert result.func(5, 10) == op(func1(5, 10), func2(10))
+ assert isinstance(result.ndarray, np.ndarray)
+
+
+def test_function_field_logical_operators(binary_logical_op):
+ func1 = lambda x, y: x > 5
+ func2 = lambda y: y < 10
+
+ domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)))
+ domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),))
+
+ ff1 = funcf.FunctionField(func1, domain1)
+ ff2 = funcf.FunctionField(func2, domain2)
+
+ result = binary_logical_op(ff1, ff2)
+
+ assert result.func(5, 10) == binary_logical_op(func1(5, 10), func2(10))
+ assert isinstance(result.ndarray, np.ndarray)
+
+
+@pytest.mark.parametrize(
+ "domain,expected_shape",
+ [
+ (common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), (10, 10)),
+ ],
+)
+def test_function_field_ndarray(domain, expected_shape):
+ ff = funcf.FunctionField(adder, domain)
+ assert ff.ndarray.shape == expected_shape
+
+ ff_func = lambda *indices: adder(*indices)
+ expected_values = np.fromfunction(ff_func, ff.ndarray.shape)
+ assert np.allclose(ff.ndarray, expected_values)
+
+
+@pytest.mark.parametrize(
+ "domain",
+ [
+ common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))),
+ ],
+)
+def test_function_field_with_field(domain):
+ ff = funcf.FunctionField(adder, domain)
+ field = common.field(np.ones((10, 10)), domain=domain)
+
+ result = ff + field
+ ff_func = lambda *indices: adder(*indices) + 1
+ expected_values = np.fromfunction(ff_func, result.ndarray.shape)
+
+ assert result.ndarray.shape == (10, 10)
+ assert np.allclose(result.ndarray, expected_values)
+
+
+def test_function_field_function_field_op():
+ res = funcf.FunctionField(
+ lambda x, y: x + 42 * y,
+ domain=common.Domain(
+ dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
+ ),
+ ) + funcf.FunctionField(
+ lambda y: 2 * y, domain=common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),))
+ )
+
+ assert res.func(1, 2) == 89
+
+
+@pytest.fixture
+def function_field():
+ return funcf.FunctionField(
+ adder,
+ domain=common.Domain(
+ dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
+ ),
+ )
+
+
+def test_function_field_unary(function_field):
+ pos_result = +function_field
+ assert pos_result.func(1, 2) == 3
+
+ neg_result = -function_field
+ assert neg_result.func(1, 2) == -3
+
+ abs_result = abs(function_field)
+ assert abs_result.func(1, 2) == 3
+
+
+def test_function_field_scalar_op(function_field):
+ new = function_field * 5.0
+ assert new.func(1, 2) == 15
+
+
+@pytest.mark.parametrize("func", ["foo", 1.0, 1])
+def test_function_field_invalid_func(func):
+ with pytest.raises(embedded_exceptions.FunctionFieldError, match="Invalid first argument type"):
+ funcf.FunctionField(func)
+
+
+@pytest.mark.parametrize(
+ "domain",
+ [
+ common.Domain(),
+ common.Domain(*((I, UnitRange(1, 10)), (J, UnitRange(5, 10)))),
+ ],
+)
+def test_function_field_invalid_invariant(domain):
+ with pytest.raises(embedded_exceptions.FunctionFieldError, match="Invariant violation"):
+ funcf.FunctionField(lambda *args, x: x, domain)
+
+
+def test_function_field_infinite_range(get_infinite_domain, get_mixed_domain):
+ domains = [get_infinite_domain, get_mixed_domain]
+ for d in domains:
+ with pytest.raises(embedded_exceptions.InfiniteRangeNdarrayError):
+ ff = funcf.FunctionField(adder, d)
+ ff.ndarray
+
+
+def test_unary_logical_op_boolean():
+ boolean_func = lambda x: x % 2 != 0
+ field = funcf.FunctionField(boolean_func, common.Domain((I, UnitRange(1, 10))))
+ assert np.allclose(~field.ndarray, np.invert(np.fromfunction(boolean_func, (9,))))
+
+
+def test_unary_logical_op_scalar():
+ scalar_func = lambda x: x % 2
+ field = funcf.FunctionField(scalar_func, common.Domain((I, UnitRange(1, 10))))
+ with pytest.raises(NotImplementedError):
+ ~field
diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
index 1a38e5245e..67341281eb 100644
--- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
+++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
@@ -11,23 +11,22 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
-import dataclasses
import itertools
-import math
import operator
-from typing import Callable, Iterable
import numpy as np
import pytest
-from gt4py.next import common, embedded
+from gt4py.next import Dimension, common, embedded
from gt4py.next.common import Dimension, Domain, UnitRange
-from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field
+from gt4py.next.embedded import (
+ exceptions as embedded_exceptions,
+ function_field as funcf,
+ nd_array_field,
+)
from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice
from gt4py.next.ffront import fbuiltins
-from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data
-
IDim = Dimension("IDim")
JDim = Dimension("JDim")
@@ -70,10 +69,23 @@ def unary_logical_op(request):
yield request.param
-def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=None):
+@pytest.fixture(
+ params=[
+ lambda x, y: operator.truediv(y, x), # Reverse true division
+ lambda x, y: operator.add(y, x), # Reverse addition
+ lambda x, y: operator.mul(y, x), # Reverse multiplication
+ lambda x, y: operator.sub(y, x), # Reverse subtraction
+ lambda x, y: operator.floordiv(y, x), # Reverse floor division
+ ]
+)
+def binary_reverse_arithmetic_op(request):
+ yield request.param
+
+
+def _make_base_ndarray_field(arr: np.ndarray, nd_array_implementation, *, domain=None, dtype=None):
if not dtype:
dtype = nd_array_implementation.float32
- buffer = nd_array_implementation.asarray(lst, dtype=dtype)
+ buffer = nd_array_implementation.asarray(arr, dtype=dtype)
if domain is None:
domain = tuple(
(common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape)
@@ -84,23 +96,37 @@ def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=No
)
-@pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data())
-def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementation):
+def _make_function_field():
+ def adder(i, j):
+ return i + j
+
+ return funcf.FunctionField(
+ adder,
+ domain=common.Domain(
+ dims=(IDim, JDim), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
+ ),
+ )
+
+
+normal_dist = np.random.normal(3, 2.5, size=(10,))
+
+
+@pytest.fixture(params=[_make_base_ndarray_field(normal_dist, np), _make_function_field()])
+def all_field_types(request):
+ yield request.param
+
+
+@pytest.mark.parametrize("builtin_name", funcf._UNARY_BUILTINS)
+def test_unary_builtins_for_all_fields(all_field_types, builtin_name):
if builtin_name == "gamma":
# numpy has no gamma function
pytest.xfail("TODO: implement gamma")
- ref_impl: Callable = np.vectorize(math.gamma)
- else:
- ref_impl: Callable = getattr(np, builtin_name)
-
- expected = ref_impl(*[np.asarray(inp, dtype=np.float32) for inp in inputs])
-
- field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs]
- builtin = getattr(fbuiltins, builtin_name)
- result = builtin(*field_inputs)
+ fbuiltin_func = getattr(fbuiltins, builtin_name)
+ result = fbuiltin_func(all_field_types).asnumpy()
- assert np.allclose(result.ndarray, expected)
+ expected = getattr(np, builtin_name)(all_field_types.asnumpy())
+ assert np.allclose(result, expected, equal_nan=True)
def test_where_builtin(nd_array_implementation):
@@ -108,7 +134,9 @@ def test_where_builtin(nd_array_implementation):
true_ = np.asarray([1.0, 2.0], dtype=np.float32)
false_ = np.asarray([3.0, 4.0], dtype=np.float32)
- field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]]
+ field_inputs = [
+ _make_base_ndarray_field(inp, nd_array_implementation) for inp in [cond, true_, false_]
+ ]
expected = np.where(cond, true_, false_)
result = fbuiltins.where(*field_inputs)
@@ -148,27 +176,36 @@ def test_where_builtin_with_tuple(nd_array_implementation):
expected0 = np.where(cond, true0, false0)
expected1 = np.where(cond, true1, false1)
- cond_field = _make_field(cond, nd_array_implementation, dtype=bool)
- field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1])
- field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1])
+ cond_field = _make_base_ndarray_field(cond, nd_array_implementation, dtype=bool)
+ field_true = tuple(
+ _make_base_ndarray_field(inp, nd_array_implementation) for inp in [true0, true1]
+ )
+ field_false = tuple(
+ _make_base_ndarray_field(inp, nd_array_implementation) for inp in [false0, false1]
+ )
result = fbuiltins.where(cond_field, field_true, field_false)
assert np.allclose(result[0].ndarray, expected0)
assert np.allclose(result[1].ndarray, expected1)
-def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation):
+def test_binary_arithmetic_ops(
+ binary_arithmetic_op, binary_reverse_arithmetic_op, nd_array_implementation
+):
inp_a = [-1.0, 4.2, 42]
inp_b = [2.0, 3.0, -3.0]
inputs = [inp_a, inp_b]
- expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs])
+ ops = [binary_arithmetic_op, binary_reverse_arithmetic_op]
- field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs]
+ for op in ops:
+ expected = op(*[np.asarray(inp, dtype=np.float32) for inp in inputs])
- result = binary_arithmetic_op(*field_inputs)
+ field_inputs = [_make_base_ndarray_field(inp, nd_array_implementation) for inp in inputs]
- assert np.allclose(result.ndarray, expected)
+ result = op(*field_inputs)
+
+ assert np.allclose(result.ndarray, expected)
def test_binary_logical_ops(binary_logical_op, nd_array_implementation):
@@ -178,7 +215,9 @@ def test_binary_logical_ops(binary_logical_op, nd_array_implementation):
expected = binary_logical_op(*[np.asarray(inp) for inp in inputs])
- field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs]
+ field_inputs = [
+ _make_base_ndarray_field(inp, nd_array_implementation, dtype=bool) for inp in inputs
+ ]
result = binary_logical_op(*field_inputs)
@@ -193,7 +232,7 @@ def test_unary_logical_ops(unary_logical_op, nd_array_implementation):
expected = unary_logical_op(np.asarray(inp))
- field_input = _make_field(inp, nd_array_implementation, dtype=bool)
+ field_input = _make_base_ndarray_field(inp, nd_array_implementation, dtype=bool)
result = unary_logical_op(field_input)
@@ -205,7 +244,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation):
expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32))
- field_input = _make_field(inp, nd_array_implementation)
+ field_input = _make_base_ndarray_field(inp, nd_array_implementation)
result = unary_arithmetic_op(field_input)
@@ -256,8 +295,8 @@ def test_mixed_fields(product_nd_array_implementation):
expected = np.asarray(inp_a) + np.asarray(inp_b)
- field_inp_a = _make_field(inp_a, first_impl)
- field_inp_b = _make_field(inp_b, second_impl)
+ field_inp_a = _make_base_ndarray_field(inp_a, first_impl)
+ field_inp_b = _make_base_ndarray_field(inp_b, second_impl)
result = field_inp_a + field_inp_b
assert np.allclose(result.ndarray, expected)
@@ -274,9 +313,9 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field:
expected = np.asarray(inp_a) * np.asarray(inp_b) + np.asarray(inp_c)
- field_inp_a = _make_field(inp_a, np)
- field_inp_b = _make_field(inp_b, np)
- field_inp_c = _make_field(inp_c, np)
+ field_inp_a = _make_base_ndarray_field(inp_a, np)
+ field_inp_b = _make_base_ndarray_field(inp_b, np)
+ field_inp_c = _make_base_ndarray_field(inp_c, np)
result = fma(field_inp_a, field_inp_b, field_inp_c)
assert np.allclose(result.ndarray, expected)