Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Support for direct field operator call with domain arg #1779

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
365c942
Support for direct field operator call with domain arg
tehrengruber Dec 10, 2024
7bb21fe
Merge branch 'main' into direct_fo_call_with_domain_arg
tehrengruber Dec 10, 2024
aed4d1e
Support for calling a program with field arguments whose domain does …
tehrengruber Dec 10, 2024
1e0aa93
Merge branch 'field_arg_with_non_zero_domain_start' into direct_fo_ca…
tehrengruber Dec 10, 2024
f722c14
Add test for input arg with different domain
tehrengruber Dec 11, 2024
c5a61e9
Fix format
tehrengruber Dec 11, 2024
9e09c86
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Dec 11, 2024
9deb814
update dace backend
edopao Dec 11, 2024
61feb99
Fix failing tests
tehrengruber Jan 10, 2025
30a4911
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
9d97ea7
Merge branch 'field_arg_with_non_zero_domain_start' into direct_fo_ca…
tehrengruber Jan 10, 2025
3f15911
Disable in dace backend
tehrengruber Jan 10, 2025
fd95ff4
Merge branch 'main' into direct_fo_call_with_domain_arg
tehrengruber Jan 10, 2025
052c54b
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 10, 2025
d1009be
Fix gpu tests
tehrengruber Jan 10, 2025
0d903cc
Address review comments
tehrengruber Jan 10, 2025
7b77c9f
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
a6cf988
Merge origin/main
tehrengruber Jan 10, 2025
858a573
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 14, 2025
da65ca1
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 15, 2025
e6e640c
dace support for domain range and field origin
edopao Jan 15, 2025
a9f67f9
minor edit
edopao Jan 15, 2025
b97232a
Revert "minor edit"
edopao Jan 16, 2025
56ec88d
Revert "dace support for domain range and field origin"
edopao Jan 16, 2025
ad68fac
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 16, 2025
a28fbf3
skip dace orchestration tests
edopao Jan 16, 2025
9637866
skip dace test_halo_exchange_helper_attrs
edopao Jan 16, 2025
6f58be6
Merge direct_fo_call_with_domain_arg
tehrengruber Jan 17, 2025
63cb251
Merge remote-tracking branch 'origin/main' into direct_fo_call_with_d…
tehrengruber Jan 17, 2025
9fb7d2e
Fix pytest mark
tehrengruber Jan 17, 2025
bbdae9d
Fix pytest mark
tehrengruber Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,10 @@ def __call__(self, *args, **kwargs) -> None:
if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
if "domain" in kwargs:
domain = common.domain(kwargs.pop("domain"))
out = out[domain]

args, kwargs = type_info.canonicalize_arguments(
self.foast_stage.foast_node.type, args, kwargs
)
Expand Down
12 changes: 9 additions & 3 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def verify(
fieldview_prog: decorator.FieldOperator | decorator.Program,
*args: FieldViewArg,
ref: ReferenceValue,
domain: Optional[dict[common.Dimension, tuple[int, int]]] = None,
out: Optional[FieldViewInout] = None,
inout: Optional[FieldViewInout] = None,
offset_provider: Optional[OffsetProvider] = None,
Expand All @@ -405,6 +406,8 @@ def verify(
or tuple of fields here and they will be compared to ``ref`` under
the assumption that the fieldview code stores its results in
them.
domain: If given will be passed to the fieldview code as ``domain=``
keyword argument.
offset_provider: An override for the test case's offset_provider.
Use with care!
comparison: A comparison function, which will be called as
Expand All @@ -414,10 +417,13 @@ def verify(
used as an argument to the fieldview program and compared against ``ref``.
Else, ``inout`` will not be passed and compared to ``ref``.
"""
kwargs = {}
if out:
run(case, fieldview_prog, *args, out=out, offset_provider=offset_provider)
else:
run(case, fieldview_prog, *args, offset_provider=offset_provider)
kwargs["out"] = out
if domain:
kwargs["domain"] = domain

run(case, fieldview_prog, *args, **kwargs, offset_provider=offset_provider)

out_comp = out or inout
assert out_comp is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import pytest

from gt4py.next import errors
from gt4py.next import errors, common, constructors
from gt4py.next.ffront.decorator import field_operator, program, scan_operator
from gt4py.next.ffront.fbuiltins import broadcast, int32

Expand Down Expand Up @@ -296,3 +296,21 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te
)
is not None
)


@pytest.mark.uses_origin
def test_direct_fo_call_with_domain_arg(cartesian_case):
@field_operator
def testee(inp: IField) -> IField:
return inp

size = cartesian_case.default_sizes[IDim]
inp = cases.allocate(cartesian_case, testee, "inp").unique()()
out = cases.allocate(
cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42)
)()
ref = inp.array_ns.zeros(size)
ref[0] = ref[-1] = 42
ref[1:-1] = inp.ndarray[1:-1]

cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref)
Loading