Skip to content

Commit

Permalink
feat[next]: Support for direct field operator call with domain arg (#…
Browse files Browse the repository at this point in the history
…1779)

Adds support for directly calling a field operator with a domain
argument, which was previously only supported inside of a program. Many
field operators in icon4py use the domain argument resulting in
excessive amounts of boilerplate programs that can be removed now.
```python
@field_operator
def testee(inp: IField) -> IField:
    return inp

testee(inp, domain={IDim: (0, 10)})
```

Support in the dace backend is missing and will be added in a seperate
PR.

---------

Co-authored-by: Edoardo Paone <edoardo.paone@cscs.ch>
  • Loading branch information
tehrengruber and edopao authored Jan 18, 2025
1 parent 1b17202 commit 517e1e9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,10 @@ def __call__(self, *args, **kwargs) -> None:
if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
if "domain" in kwargs:
domain = common.domain(kwargs.pop("domain"))
out = out[domain]

args, kwargs = type_info.canonicalize_arguments(
self.foast_stage.foast_node.type, args, kwargs
)
Expand Down
12 changes: 9 additions & 3 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def verify(
fieldview_prog: decorator.FieldOperator | decorator.Program,
*args: FieldViewArg,
ref: ReferenceValue,
domain: Optional[dict[common.Dimension, tuple[int, int]]] = None,
out: Optional[FieldViewInout] = None,
inout: Optional[FieldViewInout] = None,
offset_provider: Optional[OffsetProvider] = None,
Expand All @@ -405,6 +406,8 @@ def verify(
or tuple of fields here and they will be compared to ``ref`` under
the assumption that the fieldview code stores its results in
them.
domain: If given will be passed to the fieldview code as ``domain=``
keyword argument.
offset_provider: An override for the test case's offset_provider.
Use with care!
comparison: A comparison function, which will be called as
Expand All @@ -414,10 +417,13 @@ def verify(
used as an argument to the fieldview program and compared against ``ref``.
Else, ``inout`` will not be passed and compared to ``ref``.
"""
kwargs = {}
if out:
run(case, fieldview_prog, *args, out=out, offset_provider=offset_provider)
else:
run(case, fieldview_prog, *args, offset_provider=offset_provider)
kwargs["out"] = out
if domain:
kwargs["domain"] = domain

run(case, fieldview_prog, *args, **kwargs, offset_provider=offset_provider)

out_comp = out or inout
assert out_comp is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import pytest

from gt4py.next import errors
from gt4py.next import errors, common, constructors
from gt4py.next.ffront.decorator import field_operator, program, scan_operator
from gt4py.next.ffront.fbuiltins import broadcast, int32

Expand Down Expand Up @@ -296,3 +296,21 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te
)
is not None
)


@pytest.mark.uses_origin
def test_direct_fo_call_with_domain_arg(cartesian_case):
@field_operator
def testee(inp: IField) -> IField:
return inp

size = cartesian_case.default_sizes[IDim]
inp = cases.allocate(cartesian_case, testee, "inp").unique()()
out = cases.allocate(
cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42)
)()
ref = inp.array_ns.zeros(size)
ref[0] = ref[-1] = 42
ref[1:-1] = inp.ndarray[1:-1]

cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref)

0 comments on commit 517e1e9

Please sign in to comment.