Skip to content

Commit

Permalink
Implement invert correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 13, 2023
1 parent f243469 commit f248c69
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __post_init__(self):
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({params})",
)

@property
def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
return core_defs.dtype(self.ndarray.dtype.type)

def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
new_domain = embedded_common.sub_domain(self.domain, index)
return self.__class__(self.func, new_domain)
Expand Down Expand Up @@ -206,12 +210,14 @@ def __neg__(self) -> common.Field:
def __abs__(self) -> common.Field:
return self._unary_op(abs)

def __invert__(self) -> common.Field:
if self.dtype == core_defs.BoolDType():
return self._unary_op(operator.invert)
raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.")

def __call__(self, *args, **kwargs) -> common.Field:
return self.func(*args, **kwargs)

def __invert__(self) -> common.Field:
raise NotImplementedError("Method invert not implemented")

def remap(self, *args, **kwargs) -> common.Field:
raise NotImplementedError("Method remap not implemented")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import math
import operator
import array

import numpy as np
import pytest
Expand Down Expand Up @@ -310,3 +309,16 @@ def test_function_field_builtins(function_field, builtin_name):
assert math.isnan(np.__getattribute__(builtin_name)(3))
else:
assert result == np.__getattribute__(builtin_name)(3)


def test_unary_logical_op_boolean():
boolean_func = lambda x: x % 2 != 0
field = funcf.FunctionField(boolean_func, common.Domain((I, UnitRange(1, 10))))
assert np.allclose(~field.ndarray, np.invert(np.fromfunction(boolean_func, (9,))))


def test_unary_logical_op_scalar():
scalar_func = lambda x: x % 2
field = funcf.FunctionField(scalar_func, common.Domain((I, UnitRange(1, 10))))
with pytest.raises(NotImplementedError):
~field

0 comments on commit f248c69

Please sign in to comment.