Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor advection stencils to use StencilTest #336

Merged
merged 27 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0dbb5ea
Refactor a few advection stencils
samkellerhals Dec 13, 2023
0fd8ea4
Disable C2V offset provider due to missing connectivity
samkellerhals Dec 14, 2023
3a5f2e5
Update with main
samkellerhals Jan 8, 2024
a8a622f
Allow comparing subset of output field
samkellerhals Jan 8, 2024
c983e7c
port btraj_dreg_03
samkellerhals Jan 8, 2024
7a028fc
Add face_val_ppm_stencil_02
samkellerhals Jan 8, 2024
00bc933
Add face_val_ppm_02a
samkellerhals Jan 8, 2024
db284fe
Add Output helper class
samkellerhals Jan 9, 2024
ebcee6a
Add more stencils and reshape helper
samkellerhals Jan 9, 2024
d89a007
Fix test
samkellerhals Jan 9, 2024
d0b14dc
Add hflx limiter 2
samkellerhals Jan 10, 2024
afaba3f
Add hflx limiter 4
samkellerhals Jan 10, 2024
a810ab7
Add hflx limiter pd 1
samkellerhals Jan 10, 2024
62b543f
add horadv stencil 1
samkellerhals Jan 11, 2024
e786e21
rbf intp edge stencil 1
samkellerhals Jan 11, 2024
37b1c4d
add zero stencils
samkellerhals Jan 11, 2024
f150408
add upwind_vflux_ppm 1
samkellerhals Jan 11, 2024
7c817b7
TestVertAdvStencil01
samkellerhals Jan 11, 2024
b82d9c7
Add vlimit prbl sm stencils
samkellerhals Jan 11, 2024
90db4d8
Add missing assert in test
samkellerhals Jan 11, 2024
057d9c2
Fix error in numpy test_hflx_limiter_mo_stencil_03
Jan 12, 2024
b154bc1
Fix missing numpy ref. stencils calls
Jan 12, 2024
ffb6b79
refactor hflx limiter mo stencil 03
samkellerhals Jan 16, 2024
d971c05
hflx limiter pd stencil 02
samkellerhals Jan 16, 2024
fab8722
Fix test
samkellerhals Jan 16, 2024
c6cb54d
Use GridType
samkellerhals Jan 17, 2024
de88c3f
Slice p_face
samkellerhals Jan 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, broadcast, where

Expand All @@ -30,7 +30,7 @@ def _btraj_dreg_stencil_01(
return lvn_sys_pos


@program
@program(grid_type=GridType.UNSTRUCTURED)
def btraj_dreg_stencil_01(
lcounterclock: bool,
p_vn: Field[[EdgeDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, broadcast, int32, sqrt, where

Expand All @@ -34,7 +34,7 @@ def _btraj_dreg_stencil_02(
return opt_famask_dsl


@program
@program(grid_type=GridType.UNSTRUCTURED)
def btraj_dreg_stencil_02(
p_vn: Field[[EdgeDim, KDim], float],
p_vt: Field[[EdgeDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, int32, where

Expand Down Expand Up @@ -102,7 +102,7 @@ def _btraj_dreg_stencil_03(
)


@program
@program(grid_type=GridType.UNSTRUCTURED)
def btraj_dreg_stencil_03(
p_vn: Field[[EdgeDim, KDim], float],
p_vt: Field[[EdgeDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, broadcast, int32, where

Expand Down Expand Up @@ -70,7 +70,7 @@ def _face_val_ppm_stencil_01(
return z_slope


@program
@program(grid_type=GridType.UNSTRUCTURED)
def face_val_ppm_stencil_01(
p_cc: Field[[CellDim, KDim], float],
p_cellhgt_mc_now: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, broadcast, int32, where

Expand Down Expand Up @@ -71,7 +71,7 @@ def _face_val_ppm_stencil_02(
return p_face


@program
@program(grid_type=GridType.UNSTRUCTURED)
def face_val_ppm_stencil_02(
p_cc: Field[[CellDim, KDim], float],
p_cellhgt_mc_now: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field

Expand All @@ -29,7 +29,7 @@ def _face_val_ppm_stencil_02a(
return p_face


@program
@program(grid_type=GridType.UNSTRUCTURED)
def face_val_ppm_stencil_02a(
p_cc: Field[[CellDim, KDim], float],
p_cellhgt_mc_now: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field

Expand All @@ -25,7 +25,7 @@ def _face_val_ppm_stencil_02b(
return p_face


@program
@program(grid_type=GridType.UNSTRUCTURED)
def face_val_ppm_stencil_02b(
p_cc: Field[[CellDim, KDim], float],
p_face: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field

Expand All @@ -25,7 +25,7 @@ def _face_val_ppm_stencil_02c(
return p_face


@program
@program(grid_type=GridType.UNSTRUCTURED)
def face_val_ppm_stencil_02c(
p_cc: Field[[CellDim, KDim], float],
p_face: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field

Expand Down Expand Up @@ -51,7 +51,7 @@ def _face_val_ppm_stencil_05(
return p_face


@program
@program(grid_type=GridType.UNSTRUCTURED)
def face_val_ppm_stencil_05(
p_cc: Field[[CellDim, KDim], float],
p_cellhgt_mc_now: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.common import Field
from gt4py.next.common import Field, GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import int32, maximum, minimum, where

Expand Down Expand Up @@ -45,7 +45,7 @@ def _hflx_limiter_mo_stencil_02(
return (z_tracer_new_out, z_tracer_max_out, z_tracer_min_out)


@program
@program(grid_type=GridType.UNSTRUCTURED)
def hflx_limiter_mo_stencil_02(
refin_ctrl: Field[[CellDim], int32],
p_cc: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
def _hflx_limiter_pd_stencil_02(
refin_ctrl: Field[[EdgeDim], int32],
r_m: Field[[CellDim, KDim], float],
p_mflx_tracer_h_in: Field[[EdgeDim, KDim], float],
p_mflx_tracer_h: Field[[EdgeDim, KDim], float],
bound: int32,
) -> Field[[EdgeDim, KDim], float]:
p_mflx_tracer_h_out = where(
refin_ctrl == bound,
p_mflx_tracer_h_in,
p_mflx_tracer_h,
where(
p_mflx_tracer_h_in >= 0.0,
p_mflx_tracer_h_in * r_m(E2C[0]),
p_mflx_tracer_h_in * r_m(E2C[1]),
p_mflx_tracer_h >= 0.0,
p_mflx_tracer_h * r_m(E2C[0]),
p_mflx_tracer_h * r_m(E2C[1]),
),
)
return p_mflx_tracer_h_out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.common import Field
from gt4py.next.common import Field, GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import broadcast

Expand All @@ -23,6 +23,6 @@ def _set_zero_c() -> Field[[CellDim], float]:
return broadcast(0.0, (CellDim,))


@program
@program(grid_type=GridType.UNSTRUCTURED)
def set_zero_c(field: Field[[CellDim], float]):
_set_zero_c(out=field)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.common import Field
from gt4py.next.common import Field, GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import broadcast

Expand All @@ -23,6 +23,6 @@ def _set_zero_c_k() -> Field[[CellDim, KDim], float]:
return broadcast(0.0, (CellDim, KDim))


@program
@program(grid_type=GridType.UNSTRUCTURED)
def set_zero_c_k(field: Field[[CellDim, KDim], float]):
_set_zero_c_k(out=field)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.common import Field
from gt4py.next.common import Field, GridType
from gt4py.next.ffront.decorator import field_operator, program

from icon4py.model.common.dimension import CellDim, KDim
Expand All @@ -29,7 +29,7 @@ def _upwind_vflux_ppm_stencil_01(
return z_delta_q, z_a1


@program
@program(grid_type=GridType.UNSTRUCTURED)
def upwind_vflux_ppm_stencil_01(
z_face_up: Field[[CellDim, KDim], float],
z_face_low: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import ( # noqa: A004 # import gt4py builtin
Field,
Expand Down Expand Up @@ -39,7 +39,7 @@ def _v_limit_prbl_sm_stencil_01(
return l_limit


@program
@program(grid_type=GridType.UNSTRUCTURED)
def v_limit_prbl_sm_stencil_01(
p_face: Field[[CellDim, KDim], float],
p_cc: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field, FieldOffset, int32, minimum, where

Expand Down Expand Up @@ -43,7 +43,7 @@ def _v_limit_prbl_sm_stencil_02(
return q_face_up, q_face_low


@program
@program(grid_type=GridType.UNSTRUCTURED)
def v_limit_prbl_sm_stencil_02(
l_limit: Field[[CellDim, KDim], int32],
p_face: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import Field

Expand All @@ -35,7 +35,7 @@ def _vert_adv_stencil_01(
return tracer_new


@program
@program(grid_type=GridType.UNSTRUCTURED)
def vert_adv_stencil_01(
tracer_now: Field[[CellDim, KDim], float],
rhodz_now: Field[[CellDim, KDim], float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,54 +10,43 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np
import pytest

from icon4py.model.atmosphere.advection.btraj_dreg_stencil_01 import btraj_dreg_stencil_01
from icon4py.model.common.dimension import EdgeDim, KDim
from icon4py.model.common.grid.simple import SimpleGrid
from icon4py.model.common.test_utils.helpers import random_field, zero_field


def btraj_dreg_stencil_01_numpy(
lcounterclock: bool,
p_vn: np.array,
tangent_orientation: np.array,
):
tangent_orientation = np.expand_dims(tangent_orientation, axis=-1)

tangent_orientation = np.broadcast_to(tangent_orientation, p_vn.shape)

lvn_sys_pos_true = np.where(tangent_orientation * p_vn >= 0.0, True, False)

mask_lcounterclock = np.broadcast_to(lcounterclock, p_vn.shape)
from icon4py.model.common.test_utils.helpers import StencilTest, random_field, zero_field

lvn_sys_pos = np.where(mask_lcounterclock, lvn_sys_pos_true, False)

return lvn_sys_pos
class TestBtrajDregStencil01(StencilTest):
PROGRAM = btraj_dreg_stencil_01
OUTPUTS = ("lvn_sys_pos",)

@staticmethod
def reference(
grid, lcounterclock: bool, p_vn: np.array, tangent_orientation: np.array, **kwargs
):
tangent_orientation = np.expand_dims(tangent_orientation, axis=-1)

def test_btraj_dreg_stencil_01(backend):
grid = SimpleGrid()
lcounterclock = True
p_vn = random_field(grid, EdgeDim, KDim)
tangent_orientation = np.broadcast_to(tangent_orientation, p_vn.shape)

tangent_orientation = random_field(grid, EdgeDim)
lvn_sys_pos_true = np.where(tangent_orientation * p_vn >= 0.0, True, False)

lvn_sys_pos = zero_field(grid, EdgeDim, KDim, dtype=bool)
mask_lcounterclock = np.broadcast_to(lcounterclock, p_vn.shape)

ref = btraj_dreg_stencil_01_numpy(
lcounterclock,
p_vn.asnumpy(),
tangent_orientation.asnumpy(),
)
lvn_sys_pos = np.where(mask_lcounterclock, lvn_sys_pos_true, False)

btraj_dreg_stencil_01.with_backend(backend)(
lcounterclock,
p_vn,
tangent_orientation,
lvn_sys_pos,
offset_provider={},
)
return dict(lvn_sys_pos=lvn_sys_pos)

assert np.allclose(ref, lvn_sys_pos.asnumpy())
@pytest.fixture
def input_data(self, grid):
lcounterclock = True
p_vn = random_field(grid, EdgeDim, KDim)
tangent_orientation = random_field(grid, EdgeDim)
lvn_sys_pos = zero_field(grid, EdgeDim, KDim, dtype=bool)
return dict(
lcounterclock=lcounterclock,
p_vn=p_vn,
tangent_orientation=tangent_orientation,
lvn_sys_pos=lvn_sys_pos,
)
Loading