diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ed33aedc4c..a04d78516e 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -274,6 +274,10 @@ def __init__( def __len__(self) -> int: return len(self.ranges) + @property + def shape(self) -> tuple[int, ...]: + return tuple(len(r) for r in self.ranges) + @overload def __getitem__(self, index: int) -> NamedRange: ... diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 8cd2d8ce15..3670d5eac7 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -70,7 +70,7 @@ def _absolute_sub_domain( raise embedded_exceptions.IndexOutOfBounds(domain=domain, indices=index) for i, (dim, rng) in enumerate(domain): - if (pos := _find_index_of_dim(dim, index)) is not None: + if (pos := find_index_of_dim(dim, index)) is not None: named_idx = index[pos] idx = named_idx[1] if isinstance(idx, common.UnitRange): @@ -122,7 +122,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit return common.UnitRange(start, stop) -def _find_index_of_dim( +def find_index_of_dim( dim: common.Dimension, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], ) -> Optional[int]: @@ -132,12 +132,12 @@ def _find_index_of_dim( return None -def _broadcast_domain( +def broadcast_domain( field: common.Field, new_dimensions: tuple[common.Dimension, ...] ) -> Sequence[common.NamedRange]: named_ranges = [] for dim in new_dimensions: - if (pos := _find_index_of_dim(dim, field.domain)) is not None: + if (pos := find_index_of_dim(dim, field.domain)) is not None: named_ranges.append((dim, field.domain[pos][1])) else: named_ranges.append( diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 8f6b26e3bb..d70bb6e206 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -43,7 +43,7 @@ def __init__( class EmptyDomainIndexError(gt4py_exceptions.GT4PyError): - index: common.AnyIndexSpec + cls_name: str def __init__(self, cls_name: str): super().__init__(f"Error in `{cls_name}`: Cannot index `{cls_name}` with an empty domain.") @@ -51,6 +51,9 @@ def __init__(self, cls_name: str): class FunctionFieldError(gt4py_exceptions.GT4PyError): + cls_name: str + msg: str + def __init__(self, cls_name: str, msg: str): super().__init__(f"Error in `{cls_name}`: {msg}.") self.cls_name = cls_name @@ -58,6 +61,9 @@ def __init__(self, cls_name: str, msg: str): class InfiniteRangeNdarrayError(gt4py_exceptions.GT4PyError): + cls_name: type + domain: common.Domain + def __init__(self, cls_name: str, domain: common.Domain): super().__init__( f"Error in `{cls_name}`: Cannot construct an ndarray with an infinite range in domain: `{domain}`." diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index 1ba84eac46..b1cd2b040c 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -43,8 +43,6 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldB func (Callable): The callable function that generates field values. domain (common.Domain, optional): The domain over which the function is defined. Defaults to an empty domain. - _skip_invariant (bool, optional): Internal flag to skip invariant checks. - Defaults to False. Examples: Create a FunctionField and compute its ndarray: @@ -56,7 +54,7 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldB >>> domain = common.Domain((I, common.UnitRange(0, 5))) >>> func = lambda i: i ** 2 >>> field = FunctionField(func, domain) - >>> ndarray = field.ndarray() + >>> ndarray = field.ndarray >>> expected_ndarray = np.fromfunction(func, (5,)) >>> np.array_equal(ndarray, expected_ndarray) True @@ -64,7 +62,6 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldB func: Callable domain: common.Domain = common.Domain() - _skip_invariant: bool = False def __post_init__(self): if not callable(self.func): @@ -73,18 +70,17 @@ def __post_init__(self): f"Invalid first argument type: Expected a function but got {self.func}", ) - if not self._skip_invariant: - if __debug__: - try: - num_params = len(self.domain) - target_shape = tuple(1 for _ in range(num_params)) - np.fromfunction(self.func, target_shape) - except Exception as e: - params = _get_params(self.func) - raise embedded_exceptions.FunctionFieldError( - self.__class__.__name__, - f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({params})", - ) + if __debug__: + try: + num_params = len(self.domain) + target_shape = tuple(1 for _ in range(num_params)) + np.fromfunction(self.func, target_shape) + except Exception as e: + params = _get_params(self.func) + raise embedded_exceptions.FunctionFieldError( + self.__class__.__name__, + f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({params})", + ) def restrict(self, index: common.AnyIndexSpec) -> FunctionField: new_domain = embedded_common.sub_domain(self.domain, index) @@ -93,17 +89,12 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField: __getitem__ = restrict @property - def ndarray(self): - return self.as_array() - - def as_array(self, func: Optional[Callable[[core_defs.NDArrayObject], Any]] = None) -> core_defs.NDArrayObject | int | float: + def ndarray(self) -> core_defs.NDArrayObject | int | float: if not self.domain.is_finite(): raise embedded_exceptions.InfiniteRangeNdarrayError( self.__class__.__name__, self.domain ) - shape = [len(rng) for rng in self.domain.ranges] - _ndarray = np.fromfunction(self.func, shape) - return _ndarray if func is None else func(_ndarray) + return np.fromfunction(self.func, self.domain.shape) def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField: domain_intersection = self.domain & other.domain @@ -112,7 +103,6 @@ def _handle_function_field_op(self, other: FunctionField, op: Callable) -> Funct return self.__class__( _compose(op, broadcasted_self, broadcasted_other), domain_intersection, - _skip_invariant=True, ) def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField: @@ -120,7 +110,7 @@ def new_func(*args): return op(self.func(*args), other) return self.__class__( - new_func, self.domain, _skip_invariant=True + new_func, self.domain ) # skip invariant as we cannot deduce number of args @overload @@ -140,7 +130,7 @@ def _binary_operation(self, op, other): return op(other, self) def _unary_op(self, op: Callable) -> FunctionField: - return self.__class__(_compose(op, self), self.domain, _skip_invariant=True) + return self.__class__(_compose(op, self), self.domain) def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: return self._binary_operation(operator.add, other) @@ -213,15 +203,15 @@ def __pos__(self) -> common.Field: def __neg__(self) -> common.Field: return self._unary_op(operator.neg) - def __invert__(self) -> common.Field: - return self._unary_op(operator.invert) - def __abs__(self) -> common.Field: return self._unary_op(abs) def __call__(self, *args, **kwargs) -> common.Field: return self.func(*args, **kwargs) + def __invert__(self) -> common.Field: + raise NotImplementedError("Method invert not implemented") + def remap(self, *args, **kwargs) -> common.Field: raise NotImplementedError("Method remap not implemented") @@ -231,12 +221,12 @@ def _compose(operation: Callable, *fields: FunctionField) -> Callable: def _broadcast(field: FunctionField, dims: tuple[common.Dimension, ...]) -> FunctionField: - def broadcasted_func(*args: int): + def broadcasted_func(*args: int | core_defs.NDArrayObject): selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims] return field.func(*selected_args) - named_ranges = embedded_common._broadcast_domain(field, dims) - return FunctionField(broadcasted_func, common.Domain(*named_ranges), _skip_invariant=True) + named_ranges = embedded_common.broadcast_domain(field, dims) + return FunctionField(broadcasted_func, common.Domain(*named_ranges)) def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]: @@ -246,7 +236,7 @@ def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]: def constant_field( value: core_defs.ScalarT, domain: common.Domain = common.Domain() ) -> common.Field: - return FunctionField(lambda *args: value, domain, _skip_invariant=True) + return FunctionField(lambda *args: value, domain) def _compose_function_field_with_builtin(builtin_name: str) -> Callable: @@ -259,7 +249,7 @@ def _composed_function_field(field: FunctionField) -> FunctionField: builtin_func = getattr(np, builtin_name) new_func = lambda *args: builtin_func(field.func(*args)) - new_field: FunctionField = FunctionField(new_func, field.domain, _skip_invariant=True) + new_field: FunctionField = FunctionField(new_func, field.domain) return new_field return _composed_function_field diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index dbe6d02c4f..bfee55249f 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -25,7 +25,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.embedded import common as embedded_common -from gt4py.next.embedded.common import _broadcast_domain, _find_index_of_dim +from gt4py.next.embedded.common import broadcast_domain, find_index_of_dim from gt4py.next.ffront import fbuiltins @@ -325,7 +325,7 @@ def __setitem__( def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice = _compute_domain_slice(field, new_dimensions) - named_ranges = _broadcast_domain(field, new_dimensions) + named_ranges = broadcast_domain(field, new_dimensions) ndarray_ = field.ndarray # handle case where we have a constant FunctionField where field.ndarray is a scalar @@ -368,7 +368,7 @@ def _get_slices_from_domain_slice( slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): - if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: + if (pos := embedded_common.find_index_of_dim(dim, domain_slice)) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: @@ -411,7 +411,7 @@ def _compute_domain_slice( ) -> Sequence[slice | None]: domain_slice: list[slice | None] = [] for dim in new_dimensions: - if _find_index_of_dim(dim, field.domain) is not None: + if find_index_of_dim(dim, field.domain) is not None: domain_slice.append(slice(None)) else: domain_slice.append(np.newaxis) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 3a4f5fd66b..6fc2337c16 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -143,27 +143,27 @@ def test_sub_domain(domain, index, expected): @pytest.fixture -def finite_domain(): +def get_finite_domain(): return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) @pytest.fixture -def infinite_domain(): +def get_infinite_domain(): return common.Domain((I, UnitRange.infinity()), (J, UnitRange.infinity())) @pytest.fixture -def mixed_domain(): +def get_mixed_domain(): return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity())) -def test_finite_domain_is_finite(finite_domain): - assert finite_domain.is_finite() == True +def test_finite_domain_is_finite(get_finite_domain): + assert get_finite_domain.is_finite() -def test_infinite_domain_is_finite(infinite_domain): - assert infinite_domain.is_finite() == False +def test_infinite_domain_is_finite(get_infinite_domain): + assert not get_infinite_domain.is_finite() -def test_mixed_domain_is_finite(mixed_domain): - assert mixed_domain.is_finite() == False +def test_mixed_domain_is_finite(get_mixed_domain): + assert not get_mixed_domain.is_finite() diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py index 4bae018bd2..85eff32a90 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -23,7 +23,7 @@ from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf from gt4py.next import fbuiltins -from .test_common import infinite_domain, mixed_domain +from .test_common import get_infinite_domain, get_mixed_domain @@ -42,7 +42,6 @@ def rfloordiv(x, y): operator.mul, operator.truediv, operator.floordiv, - lambda x, y: operator.truediv(y, x), operator.pow, lambda x, y: operator.truediv(y, x), # Reverse true division lambda x, y: operator.add(y, x), # Reverse addition @@ -263,9 +262,6 @@ def test_function_field_unary(function_field): neg_result = -function_field assert neg_result.func(1, 2) == -3 - invert_result = ~function_field - assert invert_result.func(1, 2) == -4 - abs_result = abs(function_field) assert abs_result.func(1, 2) == 3 @@ -293,8 +289,8 @@ def test_function_field_invalid_invariant(domain): funcf.FunctionField(lambda *args, x: x, domain) -def test_function_field_infinite_range(infinite_domain, mixed_domain): - domains = [infinite_domain, mixed_domain] +def test_function_field_infinite_range(get_infinite_domain, get_mixed_domain): + domains = [get_infinite_domain, get_mixed_domain] for d in domains: with pytest.raises(embedded_exceptions.InfiniteRangeNdarrayError): ff = funcf.FunctionField(adder, d) @@ -314,14 +310,3 @@ def test_function_field_builtins(function_field, builtin_name): assert math.isnan(np.__getattribute__(builtin_name)(3)) else: assert result == np.__getattribute__(builtin_name)(3) - - -def test_ndarray_with_transform(function_field): - def transform_to_array(arr): - return array.array('d', arr.flatten()) - - result = function_field.as_array(func=transform_to_array) - - assert isinstance(result, array.array) - assert len(result) == 45 - assert result.typecode == 'd'