Skip to content

Commit

Permalink
refactor[next]: Use set_at & as_fieldop instead of closure in i…
Browse files Browse the repository at this point in the history
…terator tests (#1691)
  • Loading branch information
tehrengruber authored Dec 1, 2024
1 parent a26d91f commit 99c5300
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import gt4py.next as gtx
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset
from gt4py.next.program_processors.runners import double_roundtrip, roundtrip


Expand All @@ -27,16 +27,14 @@ def foo(inp):

@fendef(offset_provider={"I": I_loc, "J": J_loc})
def fencil(output, input):
closure(
cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input]
)
domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1))
set_at(as_fieldop(foo, domain)(input), domain, output)


@fendef(offset_provider={"I": J_loc, "J": I_loc})
def fencil_swapped(output, input):
closure(
cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input]
)
domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1))
set_at(as_fieldop(foo, domain)(input), domain, output)


def test_cartesian_offset_provider():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import gt4py.next as gtx
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef
from gt4py.next.iterator.runtime import set_at, fendef, fundef

from next_tests.unit_tests.conftest import program_processor, run_processor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import pytest

import gt4py.next as gtx
from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop
from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset

from next_tests.unit_tests.conftest import program_processor, run_processor
from gt4py.next.iterator.embedded import StridedConnectivityField
Expand All @@ -36,7 +36,8 @@ def foo(inp):

@fendef(offset_provider={"O": LocA2LocAB_offset_provider})
def fencil(size, out, inp):
closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp])
domain = unstructured_domain(named_range(LocA, 0, size))
set_at(as_fieldop(foo, domain)(inp), domain, out)


@pytest.mark.uses_strided_neighbor_offset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import gt4py.next as gtx
from gt4py.next.iterator import transforms
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset

from next_tests.integration_tests.cases import IDim, JDim, KDim
from next_tests.unit_tests.conftest import program_processor, run_processor
Expand Down Expand Up @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor):

@fendef
def fen_direct_deref(i_size, j_size, out, inp):
closure(
cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)),
deref,
out,
[inp],
)
domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size))
set_at(as_fieldop(deref, domain)(inp), domain, out)


def test_direct_deref(program_processor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import gt4py.next as gtx
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef
from gt4py.next.iterator.runtime import set_at, fendef, fundef

from next_tests.unit_tests.conftest import program_processor, run_processor

Expand Down Expand Up @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil):

@fendef
def fencil(size0, size1, size2, inp1, inp2, out1, out2):
closure(
cartesian_domain(
named_range(IDim, 0, size0),
named_range(JDim, 0, size1),
named_range(KDim, 0, size2),
),
stencil,
make_tuple(out1, out2),
[inp1, inp2],
domain = cartesian_domain(
named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2)
)
set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2))

shape = [5, 7, 9]
rng = np.random.default_rng()
Expand Down Expand Up @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3):

@fendef
def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3):
closure(
cartesian_domain(
named_range(IDim, 0, size0),
named_range(JDim, 0, size1),
named_range(KDim, 0, size2),
),
stencil,
domain = cartesian_domain(
named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2)
)
set_at(
as_fieldop(stencil, domain)(inp1, inp2, inp3),
domain,
make_tuple(make_tuple(out1, out2), out3),
[inp1, inp2, inp3],
)

shape = [5, 7, 9]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@
import pytest

import gt4py.next as gtx
from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.builtins import (
cartesian_domain,
deref,
lift,
named_range,
shift,
as_fieldop,
)
from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset
from gt4py.next.program_processors.runners import gtfn

from next_tests.unit_tests.conftest import program_processor, run_processor
Expand Down Expand Up @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor):

@fendef(offset_provider={"i": IDim, "j": JDim})
def fencil(x, y, z, out, inp):
closure(
cartesian_domain(
named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z)
),
stencil,
out,
[inp],
domain = cartesian_domain(
named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z)
)
set_at(as_fieldop(stencil, domain)(inp), domain, out)

shape = [5, 7, 9]
rng = np.random.default_rng()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
reduce,
tuple_get,
unstructured_domain,
as_fieldop,
)
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset

from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import (
assert_close,
Expand All @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M):

@fendef
def compute_zavgS_fencil(n_edges, out, pp, S_M):
closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M])
domain = unstructured_domain(named_range(Edge, 0, n_edges))
set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out)


@fundef
Expand Down Expand Up @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol):

@fendef
def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol):
closure(
unstructured_domain(named_range(Vertex, 0, n_nodes)),
pnabla,
out,
[pp, S_MXX, S_MYY, sign, vol],
)
domain = unstructured_domain(named_range(Vertex, 0, n_nodes))
set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out)


@pytest.mark.requires_atlas
Expand Down Expand Up @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor):

@fendef
def compute_zavgS2_fencil(n_edges, out, pp, S_M):
closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M])
domain = unstructured_domain(named_range(Edge, 0, n_edges))
set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out)


@pytest.mark.requires_atlas
Expand Down Expand Up @@ -212,12 +211,8 @@ def test_nabla(program_processor):

@fendef
def nabla2(n_nodes, out, pp, S, sign, vol):
closure(
unstructured_domain(named_range(Vertex, 0, n_nodes)),
compute_pnabla2,
out,
[pp, S, sign, vol],
)
domain = unstructured_domain(named_range(Vertex, 0, n_nodes))
set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out)


@pytest.mark.requires_atlas
Expand Down Expand Up @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge):
@fendef
def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge):
# TODO replace by single stencil which returns tuple
closure(
unstructured_domain(named_range(Vertex, 0, n_nodes)),
compute_pnabla_sign,
domain = unstructured_domain(named_range(Vertex, 0, n_nodes))
set_at(
as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge),
domain,
out_MXX,
[pp, S_MXX, vol, node_index, is_pole_edge],
)
closure(
unstructured_domain(named_range(Vertex, 0, n_nodes)),
compute_pnabla_sign,
set_at(
as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge),
domain,
out_MYY,
[pp, S_MYY, vol, node_index, is_pole_edge],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import gt4py.next as gtx
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset
from gt4py.next.program_processors.runners import gtfn

from next_tests.integration_tests.cases import IDim, JDim
Expand Down Expand Up @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff):

@fendef(offset_provider={"I": IDim, "J": JDim})
def hdiff(inp, coeff, out, x, y):
closure(
cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)),
hdiff_sten,
out,
[inp, coeff],
)
domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y))
set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out)


@pytest.mark.uses_origin
Expand Down

0 comments on commit 99c5300

Please sign in to comment.