Skip to content

Commit

Permalink
Customisable ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 12, 2023
1 parent 4ab5ab5 commit 1ca17f4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dataclasses
import inspect
import operator
from typing import Any, Callable, TypeGuard, overload
from typing import Any, Callable, TypeGuard, overload, Optional

import numpy as np

Expand Down Expand Up @@ -56,7 +56,7 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldB
>>> domain = common.Domain((I, common.UnitRange(0, 5)))
>>> func = lambda i: i ** 2
>>> field = FunctionField(func, domain)
>>> ndarray = field.ndarray
>>> ndarray = field.ndarray()
>>> expected_ndarray = np.fromfunction(func, (5,))
>>> np.array_equal(ndarray, expected_ndarray)
True
Expand Down Expand Up @@ -93,13 +93,17 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
__getitem__ = restrict

@property
def ndarray(self) -> core_defs.NDArrayObject | int | float:
def ndarray(self):
return self.as_array()

def as_array(self, func: Optional[Callable[[core_defs.NDArrayObject], Any]] = None) -> core_defs.NDArrayObject | int | float:
if not self.domain.is_finite():
raise embedded_exceptions.InfiniteRangeNdarrayError(
self.__class__.__name__, self.domain
)
shape = [len(rng) for rng in self.domain.ranges]
return np.fromfunction(self.func, shape)
_ndarray = np.fromfunction(self.func, shape)
return _ndarray if func is None else func(_ndarray)

def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField:
domain_intersection = self.domain & other.domain
Expand Down
12 changes: 12 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import math
import operator
import array

import numpy as np
import pytest
Expand Down Expand Up @@ -313,3 +314,14 @@ def test_function_field_builtins(function_field, builtin_name):
assert math.isnan(np.__getattribute__(builtin_name)(3))
else:
assert result == np.__getattribute__(builtin_name)(3)


def test_ndarray_with_transform(function_field):
def transform_to_array(arr):
return array.array('d', arr.flatten())

result = function_field.as_array(func=transform_to_array)

assert isinstance(result, array.array)
assert len(result) == 45
assert result.typecode == 'd'

0 comments on commit 1ca17f4

Please sign in to comment.