Skip to content

Commit

Permalink
Add improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 13, 2023
1 parent 1ca17f4 commit f243469
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 70 deletions.
4 changes: 4 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ def __init__(
def __len__(self) -> int:
return len(self.ranges)

@property
def shape(self) -> tuple[int, ...]:
return tuple(len(r) for r in self.ranges)

@overload
def __getitem__(self, index: int) -> NamedRange:
...
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _absolute_sub_domain(
raise embedded_exceptions.IndexOutOfBounds(domain=domain, indices=index)

for i, (dim, rng) in enumerate(domain):
if (pos := _find_index_of_dim(dim, index)) is not None:
if (pos := find_index_of_dim(dim, index)) is not None:
named_idx = index[pos]
idx = named_idx[1]
if isinstance(idx, common.UnitRange):
Expand Down Expand Up @@ -122,7 +122,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit
return common.UnitRange(start, stop)


def _find_index_of_dim(
def find_index_of_dim(
dim: common.Dimension,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
) -> Optional[int]:
Expand All @@ -132,12 +132,12 @@ def _find_index_of_dim(
return None


def _broadcast_domain(
def broadcast_domain(
field: common.Field, new_dimensions: tuple[common.Dimension, ...]
) -> Sequence[common.NamedRange]:
named_ranges = []
for dim in new_dimensions:
if (pos := _find_index_of_dim(dim, field.domain)) is not None:
if (pos := find_index_of_dim(dim, field.domain)) is not None:
named_ranges.append((dim, field.domain[pos][1]))
else:
named_ranges.append(
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/embedded/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,27 @@ def __init__(


class EmptyDomainIndexError(gt4py_exceptions.GT4PyError):
index: common.AnyIndexSpec
cls_name: str

def __init__(self, cls_name: str):
super().__init__(f"Error in `{cls_name}`: Cannot index `{cls_name}` with an empty domain.")
self.cls_name = cls_name


class FunctionFieldError(gt4py_exceptions.GT4PyError):
cls_name: str
msg: str

def __init__(self, cls_name: str, msg: str):
super().__init__(f"Error in `{cls_name}`: {msg}.")
self.cls_name = cls_name
self.msg = msg


class InfiniteRangeNdarrayError(gt4py_exceptions.GT4PyError):
cls_name: type
domain: common.Domain

def __init__(self, cls_name: str, domain: common.Domain):
super().__init__(
f"Error in `{cls_name}`: Cannot construct an ndarray with an infinite range in domain: `{domain}`."
Expand Down
58 changes: 24 additions & 34 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldB
func (Callable): The callable function that generates field values.
domain (common.Domain, optional): The domain over which the function is defined.
Defaults to an empty domain.
_skip_invariant (bool, optional): Internal flag to skip invariant checks.
Defaults to False.
Examples:
Create a FunctionField and compute its ndarray:
Expand All @@ -56,15 +54,14 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldB
>>> domain = common.Domain((I, common.UnitRange(0, 5)))
>>> func = lambda i: i ** 2
>>> field = FunctionField(func, domain)
>>> ndarray = field.ndarray()
>>> ndarray = field.ndarray
>>> expected_ndarray = np.fromfunction(func, (5,))
>>> np.array_equal(ndarray, expected_ndarray)
True
"""

func: Callable
domain: common.Domain = common.Domain()
_skip_invariant: bool = False

def __post_init__(self):
if not callable(self.func):
Expand All @@ -73,18 +70,17 @@ def __post_init__(self):
f"Invalid first argument type: Expected a function but got {self.func}",
)

if not self._skip_invariant:
if __debug__:
try:
num_params = len(self.domain)
target_shape = tuple(1 for _ in range(num_params))
np.fromfunction(self.func, target_shape)
except Exception as e:
params = _get_params(self.func)
raise embedded_exceptions.FunctionFieldError(
self.__class__.__name__,
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({params})",
)
if __debug__:
try:
num_params = len(self.domain)
target_shape = tuple(1 for _ in range(num_params))
np.fromfunction(self.func, target_shape)
except Exception as e:
params = _get_params(self.func)
raise embedded_exceptions.FunctionFieldError(
self.__class__.__name__,
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({params})",
)

def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
new_domain = embedded_common.sub_domain(self.domain, index)
Expand All @@ -93,17 +89,12 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
__getitem__ = restrict

@property
def ndarray(self):
return self.as_array()

def as_array(self, func: Optional[Callable[[core_defs.NDArrayObject], Any]] = None) -> core_defs.NDArrayObject | int | float:
def ndarray(self) -> core_defs.NDArrayObject | int | float:
if not self.domain.is_finite():
raise embedded_exceptions.InfiniteRangeNdarrayError(
self.__class__.__name__, self.domain
)
shape = [len(rng) for rng in self.domain.ranges]
_ndarray = np.fromfunction(self.func, shape)
return _ndarray if func is None else func(_ndarray)
return np.fromfunction(self.func, self.domain.shape)

def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField:
domain_intersection = self.domain & other.domain
Expand All @@ -112,15 +103,14 @@ def _handle_function_field_op(self, other: FunctionField, op: Callable) -> Funct
return self.__class__(
_compose(op, broadcasted_self, broadcasted_other),
domain_intersection,
_skip_invariant=True,
)

def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField:
def new_func(*args):
return op(self.func(*args), other)

return self.__class__(
new_func, self.domain, _skip_invariant=True
new_func, self.domain
) # skip invariant as we cannot deduce number of args

@overload
Expand All @@ -140,7 +130,7 @@ def _binary_operation(self, op, other):
return op(other, self)

def _unary_op(self, op: Callable) -> FunctionField:
return self.__class__(_compose(op, self), self.domain, _skip_invariant=True)
return self.__class__(_compose(op, self), self.domain)

def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.add, other)
Expand Down Expand Up @@ -213,15 +203,15 @@ def __pos__(self) -> common.Field:
def __neg__(self) -> common.Field:
return self._unary_op(operator.neg)

def __invert__(self) -> common.Field:
return self._unary_op(operator.invert)

def __abs__(self) -> common.Field:
return self._unary_op(abs)

def __call__(self, *args, **kwargs) -> common.Field:
return self.func(*args, **kwargs)

def __invert__(self) -> common.Field:
raise NotImplementedError("Method invert not implemented")

def remap(self, *args, **kwargs) -> common.Field:
raise NotImplementedError("Method remap not implemented")

Expand All @@ -231,12 +221,12 @@ def _compose(operation: Callable, *fields: FunctionField) -> Callable:


def _broadcast(field: FunctionField, dims: tuple[common.Dimension, ...]) -> FunctionField:
def broadcasted_func(*args: int):
def broadcasted_func(*args: int | core_defs.NDArrayObject):
selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims]
return field.func(*selected_args)

named_ranges = embedded_common._broadcast_domain(field, dims)
return FunctionField(broadcasted_func, common.Domain(*named_ranges), _skip_invariant=True)
named_ranges = embedded_common.broadcast_domain(field, dims)
return FunctionField(broadcasted_func, common.Domain(*named_ranges))


def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]:
Expand All @@ -246,7 +236,7 @@ def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]:
def constant_field(
value: core_defs.ScalarT, domain: common.Domain = common.Domain()
) -> common.Field:
return FunctionField(lambda *args: value, domain, _skip_invariant=True)
return FunctionField(lambda *args: value, domain)


def _compose_function_field_with_builtin(builtin_name: str) -> Callable:
Expand All @@ -259,7 +249,7 @@ def _composed_function_field(field: FunctionField) -> FunctionField:

builtin_func = getattr(np, builtin_name)
new_func = lambda *args: builtin_func(field.func(*args))
new_field: FunctionField = FunctionField(new_func, field.domain, _skip_invariant=True)
new_field: FunctionField = FunctionField(new_func, field.domain)
return new_field
return _composed_function_field

Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.embedded.common import _broadcast_domain, _find_index_of_dim
from gt4py.next.embedded.common import broadcast_domain, find_index_of_dim
from gt4py.next.ffront import fbuiltins


Expand Down Expand Up @@ -325,7 +325,7 @@ def __setitem__(

def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
domain_slice = _compute_domain_slice(field, new_dimensions)
named_ranges = _broadcast_domain(field, new_dimensions)
named_ranges = broadcast_domain(field, new_dimensions)
ndarray_ = field.ndarray

# handle case where we have a constant FunctionField where field.ndarray is a scalar
Expand Down Expand Up @@ -368,7 +368,7 @@ def _get_slices_from_domain_slice(
slice_indices: list[slice | common.IntIndex] = []

for pos_old, (dim, _) in enumerate(domain):
if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None:
if (pos := embedded_common.find_index_of_dim(dim, domain_slice)) is not None:
index_or_range = domain_slice[pos][1]
slice_indices.append(_compute_slice(index_or_range, domain, pos_old))
else:
Expand Down Expand Up @@ -411,7 +411,7 @@ def _compute_domain_slice(
) -> Sequence[slice | None]:
domain_slice: list[slice | None] = []
for dim in new_dimensions:
if _find_index_of_dim(dim, field.domain) is not None:
if find_index_of_dim(dim, field.domain) is not None:
domain_slice.append(slice(None))
else:
domain_slice.append(np.newaxis)
Expand Down
18 changes: 9 additions & 9 deletions tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,27 +143,27 @@ def test_sub_domain(domain, index, expected):


@pytest.fixture
def finite_domain():
def get_finite_domain():
return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4)))


@pytest.fixture
def infinite_domain():
def get_infinite_domain():
return common.Domain((I, UnitRange.infinity()), (J, UnitRange.infinity()))


@pytest.fixture
def mixed_domain():
def get_mixed_domain():
return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity()))


def test_finite_domain_is_finite(finite_domain):
assert finite_domain.is_finite() == True
def test_finite_domain_is_finite(get_finite_domain):
assert get_finite_domain.is_finite()


def test_infinite_domain_is_finite(infinite_domain):
assert infinite_domain.is_finite() == False
def test_infinite_domain_is_finite(get_infinite_domain):
assert not get_infinite_domain.is_finite()


def test_mixed_domain_is_finite(mixed_domain):
assert mixed_domain.is_finite() == False
def test_mixed_domain_is_finite(get_mixed_domain):
assert not get_mixed_domain.is_finite()
21 changes: 3 additions & 18 deletions tests/next_tests/unit_tests/embedded_tests/test_function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf
from gt4py.next import fbuiltins

from .test_common import infinite_domain, mixed_domain
from .test_common import get_infinite_domain, get_mixed_domain



Expand All @@ -42,7 +42,6 @@ def rfloordiv(x, y):
operator.mul,
operator.truediv,
operator.floordiv,
lambda x, y: operator.truediv(y, x),
operator.pow,
lambda x, y: operator.truediv(y, x), # Reverse true division
lambda x, y: operator.add(y, x), # Reverse addition
Expand Down Expand Up @@ -263,9 +262,6 @@ def test_function_field_unary(function_field):
neg_result = -function_field
assert neg_result.func(1, 2) == -3

invert_result = ~function_field
assert invert_result.func(1, 2) == -4

abs_result = abs(function_field)
assert abs_result.func(1, 2) == 3

Expand Down Expand Up @@ -293,8 +289,8 @@ def test_function_field_invalid_invariant(domain):
funcf.FunctionField(lambda *args, x: x, domain)


def test_function_field_infinite_range(infinite_domain, mixed_domain):
domains = [infinite_domain, mixed_domain]
def test_function_field_infinite_range(get_infinite_domain, get_mixed_domain):
domains = [get_infinite_domain, get_mixed_domain]
for d in domains:
with pytest.raises(embedded_exceptions.InfiniteRangeNdarrayError):
ff = funcf.FunctionField(adder, d)
Expand All @@ -314,14 +310,3 @@ def test_function_field_builtins(function_field, builtin_name):
assert math.isnan(np.__getattribute__(builtin_name)(3))
else:
assert result == np.__getattribute__(builtin_name)(3)


def test_ndarray_with_transform(function_field):
def transform_to_array(arr):
return array.array('d', arr.flatten())

result = function_field.as_array(func=transform_to_array)

assert isinstance(result, array.array)
assert len(result) == 45
assert result.typecode == 'd'

0 comments on commit f243469

Please sign in to comment.