Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test dependencies on savepoints #349

Merged
merged 12 commits into from
Feb 6, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gt4py.next.ffront.fbuiltins import broadcast, int32, minimum

from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.math.smagorinsky import en_smag_fac_for_zero_nshift
from icon4py.model.common.math.smagorinsky import _en_smag_fac_for_zero_nshift


# TODO(Magdalena): fix duplication: duplicated from test testutils/utils.py
Expand Down Expand Up @@ -115,7 +115,7 @@ def _init_diffusion_local_fields_for_regular_timestemp(
) -> tuple[Field[[KDim], float], Field[[KDim], float], Field[[KDim], float]]:
diff_multfac_vn = _setup_runtime_diff_multfac_vn(k4, dyn_substeps)
smag_limit = _setup_smag_limit(diff_multfac_vn)
enh_smag_fac = en_smag_fac_for_zero_nshift(
enh_smag_fac = _en_smag_fac_for_zero_nshift(
vect_a,
hdiff_smag_fac,
hdiff_smag_fac2,
Expand Down
2 changes: 1 addition & 1 deletion model/atmosphere/diffusion/tests/diffusion_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def construct_diagnostics(
savepoint: IconDiffusionInitSavepoint,
grid_savepoint: IconGridSavepoint,
) -> DiffusionDiagnosticState:
grid = grid_savepoint.construct_icon_grid()
grid = grid_savepoint.construct_icon_grid(on_gpu=False)
dwdx = savepoint.dwdx() if savepoint.dwdx() else zero_field(grid, CellDim, KDim)
dwdy = savepoint.dwdy() if savepoint.dwdy() else zero_field(grid, CellDim, KDim)
return DiffusionDiagnosticState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def init(
self.config.divdamp_z2,
self.config.divdamp_z3,
self.config.divdamp_z4,
out=self.enh_divdamp_fac,
self.enh_divdamp_fac,
offset_provider={"Koff": KDim},
)

Expand Down
94 changes: 44 additions & 50 deletions model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.datatest_utils import GLOBAL_EXPERIMENT, REGIONAL_EXPERIMENT
from icon4py.model.common.test_utils.helpers import dallclose
from icon4py.model.common.test_utils.serialbox_utils import IconNonHydroInitSavepoint

from .utils import (
construct_config,
Expand Down Expand Up @@ -71,7 +72,7 @@ def test_validate_divdamp_fields_against_savepoint_values(
config.divdamp_z2,
config.divdamp_z3,
config.divdamp_z4,
out=enh_divdamp_fac,
enh_divdamp_fac,
offset_provider={"Koff": KDim},
)
_calculate_scal_divdamp.with_backend(backend)(
Expand Down Expand Up @@ -109,7 +110,6 @@ def test_nonhydro_predictor_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -123,15 +123,14 @@ def test_nonhydro_predictor_step(
sp_exit = savepoint_nonhydro_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
dtime = sp.get_metadata("dtime").get("dtime")
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

nnow = 0
nnew = 1

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
Expand Down Expand Up @@ -442,29 +441,29 @@ def test_nonhydro_predictor_step(
assert dallclose(prognostic_state_nnew.theta_v.asnumpy(), sp_exit.theta_v_new().asnumpy())


def construct_diagnostics(sp, sp_v):
def construct_diagnostics(init_savepoint: IconNonHydroInitSavepoint):
return DiagnosticStateNonHydro(
theta_v_ic=sp.theta_v_ic(),
exner_pr=sp.exner_pr(),
rho_ic=sp.rho_ic(),
ddt_exner_phy=sp.ddt_exner_phy(),
grf_tend_rho=sp.grf_tend_rho(),
grf_tend_thv=sp.grf_tend_thv(),
grf_tend_w=sp.grf_tend_w(),
mass_fl_e=sp.mass_fl_e(),
ddt_vn_phy=sp.ddt_vn_phy(),
grf_tend_vn=sp.grf_tend_vn(),
ddt_vn_apc_ntl1=sp_v.ddt_vn_apc_pc(1),
ddt_vn_apc_ntl2=sp_v.ddt_vn_apc_pc(2),
ddt_w_adv_ntl1=sp_v.ddt_w_adv_pc(1),
ddt_w_adv_ntl2=sp_v.ddt_w_adv_pc(2),
vt=sp_v.vt(),
vn_ie=sp_v.vn_ie(),
w_concorr_c=sp_v.w_concorr_c(),
theta_v_ic=init_savepoint.theta_v_ic(),
exner_pr=init_savepoint.exner_pr(),
rho_ic=init_savepoint.rho_ic(),
ddt_exner_phy=init_savepoint.ddt_exner_phy(),
grf_tend_rho=init_savepoint.grf_tend_rho(),
grf_tend_thv=init_savepoint.grf_tend_thv(),
grf_tend_w=init_savepoint.grf_tend_w(),
mass_fl_e=init_savepoint.mass_fl_e(),
ddt_vn_phy=init_savepoint.ddt_vn_phy(),
grf_tend_vn=init_savepoint.grf_tend_vn(),
ddt_vn_apc_ntl1=init_savepoint.ddt_vn_apc_pc(1),
ddt_vn_apc_ntl2=init_savepoint.ddt_vn_apc_pc(2),
ddt_w_adv_ntl1=init_savepoint.ddt_w_adv_pc(1),
ddt_w_adv_ntl2=init_savepoint.ddt_w_adv_pc(2),
vt=init_savepoint.vt(),
vn_ie=init_savepoint.vn_ie(),
w_concorr_c=init_savepoint.w_concorr_c(),
rho_incr=None, # sp.rho_incr(),
vn_incr=None, # sp.vn_incr(),
exner_incr=None, # sp.exner_incr(),
exner_dyn_incr=sp.exner_dyn_incr(),
exner_dyn_incr=init_savepoint.exner_dyn_incr(),
)


Expand Down Expand Up @@ -496,7 +495,6 @@ def test_nonhydro_corrector_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -514,18 +512,17 @@ def test_nonhydro_corrector_step(
nflatlev=grid_savepoint.nflatlev(),
nflat_gradp=grid_savepoint.nflat_gradp(),
)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
dtime = sp.get_metadata("dtime").get("dtime")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0 # TODO: @abishekg7 read from serialized data?
nnew = 1

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

z_fields = IntermediateFields(
z_gradh_exner=sp.z_gradh_exner(),
Expand All @@ -540,8 +537,8 @@ def test_nonhydro_corrector_step(
z_graddiv_vn=sp.z_graddiv_vn(),
z_rho_expl=sp.z_rho_expl(),
z_dwdz_dd=sp.z_dwdz_dd(),
z_kin_hor_e=sp_v.z_kin_hor_e(),
z_vt_ie=sp_v.z_vt_ie(),
z_kin_hor_e=sp.z_kin_hor_e(),
z_vt_ie=sp.z_vt_ie(),
)

divdamp_fac_o2 = sp.divdamp_fac_o2()
Expand Down Expand Up @@ -683,7 +680,6 @@ def test_run_solve_nonhydro_single_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init, # TODO (magdalena) this should not be needed in test_solve_nonhydro.py, only for test_velocity_advection.py
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -697,20 +693,19 @@ def test_run_solve_nonhydro_single_step(
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
dtime = sp.get_metadata("dtime").get("dtime")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
Expand Down Expand Up @@ -797,7 +792,7 @@ def test_run_solve_nonhydro_multi_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
vn_only,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -810,20 +805,19 @@ def test_run_solve_nonhydro_multi_step(
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
dtime = sp.get_metadata("dtime").get("dtime")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
diagnostic_state_nh = construct_diagnostics(sp, sp_v)
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

diagnostic_state_nh = construct_diagnostics(sp)
prognostic_state_ls = create_prognostic_states(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
Expand Down
31 changes: 29 additions & 2 deletions model/common/src/icon4py/model/common/math/smagorinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import Field, field_operator
from gt4py.next import Field, field_operator, program
from gt4py.next.ffront.fbuiltins import broadcast, maximum, minimum

from icon4py.model.common.dimension import KDim, Koff


@field_operator
def en_smag_fac_for_zero_nshift(
def _en_smag_fac_for_zero_nshift(
vect_a: Field[[KDim], float],
hdiff_smag_fac: float,
hdiff_smag_fac2: float,
Expand All @@ -45,3 +45,30 @@ def en_smag_fac_for_zero_nshift(
dzqdr = minimum(broadcast(dz42, (KDim,)), maximum(zero, zf - hdiff_smag_z2))
enh_smag_fac = hdiff_smag_fac + (dzlin * alin) + dzqdr * (aqdr + dzqdr * bqdr)
return enh_smag_fac


@program
def en_smag_fac_for_zero_nshift(
vect_a: Field[[KDim], float],
hdiff_smag_fac: float,
hdiff_smag_fac2: float,
hdiff_smag_fac3: float,
hdiff_smag_fac4: float,
hdiff_smag_z: float,
hdiff_smag_z2: float,
hdiff_smag_z3: float,
hdiff_smag_z4: float,
enh_smag_fac: Field[[KDim], float],
):
_en_smag_fac_for_zero_nshift(
vect_a,
hdiff_smag_fac,
hdiff_smag_fac2,
hdiff_smag_fac3,
hdiff_smag_fac4,
hdiff_smag_z,
hdiff_smag_z2,
hdiff_smag_z3,
hdiff_smag_z4,
out=enh_smag_fac,
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def icon_grid(grid_savepoint, experiment):

Uses the special grid_savepoint that contains data from p_patch
"""
return grid_savepoint.construct_icon_grid()
return grid_savepoint.construct_icon_grid(on_gpu=False)


@pytest.fixture
Expand Down
35 changes: 23 additions & 12 deletions model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,18 @@ def construct_prognostics(self) -> PrognosticState:


class IconNonHydroInitSavepoint(IconSavepoint):
def z_vt_ie(self):
return self._get_field("z_vt_ie", EdgeDim, KDim)

def z_kin_hor_e(self):
return self._get_field("z_kin_hor_e", EdgeDim, KDim)

def vn_ie(self):
return self._get_field("vn_ie", EdgeDim, KDim)

def vt(self):
return self._get_field("vt", EdgeDim, KDim)

def bdy_divdamp(self):
return self._get_field("bdy_divdamp", KDim)

Expand Down Expand Up @@ -732,21 +744,20 @@ def grf_tend_thv(self):
def grf_tend_vn(self):
return self._get_field("grf_tend_vn", EdgeDim, KDim)

def w_concorr_c(self):
return self._get_field("w_concorr_c", CellDim, KDim)

def ddt_vn_apc_pc(self, ntnd):
return self._get_field_component("ddt_vn_apc_pc", ntnd, (EdgeDim, KDim))

def ddt_w_adv_pc(self, ntnd):
return self._get_field_component("ddt_w_adv_ntl", ntnd, (CellDim, KDim))

def ddt_vn_adv_ntl(self, ntl):
buffer = np.squeeze(self.serializer.read("ddt_vn_adv_ntl", self.savepoint).astype(float))[
:, :, ntl - 1
]
dims = (EdgeDim, KDim)
buffer = self._reduce_to_dim_size(buffer, dims)
return as_field(dims, buffer)
return self._get_field_component("ddt_vn_apc_pc", ntl, (EdgeDim, KDim))

def ddt_w_adv_ntl(self, ntl):
buffer = np.squeeze(self.serializer.read("ddt_w_adv_ntl", self.savepoint).astype(float))[
:, :, ntl - 1
]
dims = (CellDim, KDim)
buffer = self._reduce_to_dim_size(buffer, dims)
return as_field(dims, buffer)
return self._get_field_component("ddt_w_adv_ntl", ntl, (CellDim, KDim))

def grf_tend_w(self):
return self._get_field("grf_tend_w", CellDim, KDim)
Expand Down
7 changes: 5 additions & 2 deletions model/common/tests/math_tests/test_smagorinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np
from gt4py.next.program_processors.runners.roundtrip import backend as roundtrip

from icon4py.model.common.dimension import KDim
from icon4py.model.common.grid.simple import SimpleGrid
Expand All @@ -20,6 +21,8 @@
from icon4py.model.common.test_utils.reference_funcs import enhanced_smagorinski_factor_numpy


# TODO (magdalena) stencil does not run on embedded backend, broadcast(0.0, (KDim,)) return scalar?
# TODO (magdalena) run as to StencilTest
def test_init_enh_smag_fac():
grid = SimpleGrid()
enh_smag_fac = zero_field(grid, KDim)
Expand All @@ -28,11 +31,11 @@ def test_init_enh_smag_fac():
z = (0.1, 0.2, 0.3, 0.4)

enhanced_smag_fac_np = enhanced_smagorinski_factor_numpy(fac, z, a_vec.asnumpy())
en_smag_fac_for_zero_nshift(
en_smag_fac_for_zero_nshift.with_backend(roundtrip)(
a_vec,
*fac,
*z,
out=enh_smag_fac,
enh_smag_fac,
offset_provider={"Koff": KDim},
)
assert np.allclose(enhanced_smag_fac_np, enh_smag_fac.asnumpy())
Loading
Loading