Skip to content

Commit

Permalink
Fix test dependencies on savepoints (#349)
Browse files Browse the repository at this point in the history
(cleanup) Make test_solve_nonhydro.py indenpendent of `VelocityInitSavepoint`.
  • Loading branch information
halungge authored Feb 6, 2024
1 parent 40268ea commit a5c8e18
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gt4py.next.ffront.fbuiltins import broadcast, int32, minimum

from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.math.smagorinsky import en_smag_fac_for_zero_nshift
from icon4py.model.common.math.smagorinsky import _en_smag_fac_for_zero_nshift


# TODO(Magdalena): fix duplication: duplicated from test testutils/utils.py
Expand Down Expand Up @@ -115,7 +115,7 @@ def _init_diffusion_local_fields_for_regular_timestemp(
) -> tuple[Field[[KDim], float], Field[[KDim], float], Field[[KDim], float]]:
diff_multfac_vn = _setup_runtime_diff_multfac_vn(k4, dyn_substeps)
smag_limit = _setup_smag_limit(diff_multfac_vn)
enh_smag_fac = en_smag_fac_for_zero_nshift(
enh_smag_fac = _en_smag_fac_for_zero_nshift(
vect_a,
hdiff_smag_fac,
hdiff_smag_fac2,
Expand Down
2 changes: 1 addition & 1 deletion model/atmosphere/diffusion/tests/diffusion_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def construct_diagnostics(
savepoint: IconDiffusionInitSavepoint,
grid_savepoint: IconGridSavepoint,
) -> DiffusionDiagnosticState:
grid = grid_savepoint.construct_icon_grid()
grid = grid_savepoint.construct_icon_grid(on_gpu=False)
dwdx = savepoint.dwdx() if savepoint.dwdx() else zero_field(grid, CellDim, KDim)
dwdy = savepoint.dwdy() if savepoint.dwdy() else zero_field(grid, CellDim, KDim)
return DiffusionDiagnosticState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def init(
self.config.divdamp_z2,
self.config.divdamp_z3,
self.config.divdamp_z4,
out=self.enh_divdamp_fac,
self.enh_divdamp_fac,
offset_provider={"Koff": KDim},
)

Expand Down
94 changes: 44 additions & 50 deletions model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.datatest_utils import GLOBAL_EXPERIMENT, REGIONAL_EXPERIMENT
from icon4py.model.common.test_utils.helpers import dallclose
from icon4py.model.common.test_utils.serialbox_utils import IconNonHydroInitSavepoint

from .utils import (
construct_config,
Expand Down Expand Up @@ -71,7 +72,7 @@ def test_validate_divdamp_fields_against_savepoint_values(
config.divdamp_z2,
config.divdamp_z3,
config.divdamp_z4,
out=enh_divdamp_fac,
enh_divdamp_fac,
offset_provider={"Koff": KDim},
)
_calculate_scal_divdamp.with_backend(backend)(
Expand Down Expand Up @@ -109,7 +110,6 @@ def test_nonhydro_predictor_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -123,15 +123,14 @@ def test_nonhydro_predictor_step(
sp_exit = savepoint_nonhydro_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
dtime = sp.get_metadata("dtime").get("dtime")
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

nnow = 0
nnew = 1

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
Expand Down Expand Up @@ -442,29 +441,29 @@ def test_nonhydro_predictor_step(
assert dallclose(prognostic_state_nnew.theta_v.asnumpy(), sp_exit.theta_v_new().asnumpy())


def construct_diagnostics(sp, sp_v):
def construct_diagnostics(init_savepoint: IconNonHydroInitSavepoint):
return DiagnosticStateNonHydro(
theta_v_ic=sp.theta_v_ic(),
exner_pr=sp.exner_pr(),
rho_ic=sp.rho_ic(),
ddt_exner_phy=sp.ddt_exner_phy(),
grf_tend_rho=sp.grf_tend_rho(),
grf_tend_thv=sp.grf_tend_thv(),
grf_tend_w=sp.grf_tend_w(),
mass_fl_e=sp.mass_fl_e(),
ddt_vn_phy=sp.ddt_vn_phy(),
grf_tend_vn=sp.grf_tend_vn(),
ddt_vn_apc_ntl1=sp_v.ddt_vn_apc_pc(1),
ddt_vn_apc_ntl2=sp_v.ddt_vn_apc_pc(2),
ddt_w_adv_ntl1=sp_v.ddt_w_adv_pc(1),
ddt_w_adv_ntl2=sp_v.ddt_w_adv_pc(2),
vt=sp_v.vt(),
vn_ie=sp_v.vn_ie(),
w_concorr_c=sp_v.w_concorr_c(),
theta_v_ic=init_savepoint.theta_v_ic(),
exner_pr=init_savepoint.exner_pr(),
rho_ic=init_savepoint.rho_ic(),
ddt_exner_phy=init_savepoint.ddt_exner_phy(),
grf_tend_rho=init_savepoint.grf_tend_rho(),
grf_tend_thv=init_savepoint.grf_tend_thv(),
grf_tend_w=init_savepoint.grf_tend_w(),
mass_fl_e=init_savepoint.mass_fl_e(),
ddt_vn_phy=init_savepoint.ddt_vn_phy(),
grf_tend_vn=init_savepoint.grf_tend_vn(),
ddt_vn_apc_ntl1=init_savepoint.ddt_vn_apc_pc(1),
ddt_vn_apc_ntl2=init_savepoint.ddt_vn_apc_pc(2),
ddt_w_adv_ntl1=init_savepoint.ddt_w_adv_pc(1),
ddt_w_adv_ntl2=init_savepoint.ddt_w_adv_pc(2),
vt=init_savepoint.vt(),
vn_ie=init_savepoint.vn_ie(),
w_concorr_c=init_savepoint.w_concorr_c(),
rho_incr=None, # sp.rho_incr(),
vn_incr=None, # sp.vn_incr(),
exner_incr=None, # sp.exner_incr(),
exner_dyn_incr=sp.exner_dyn_incr(),
exner_dyn_incr=init_savepoint.exner_dyn_incr(),
)


Expand Down Expand Up @@ -496,7 +495,6 @@ def test_nonhydro_corrector_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -514,18 +512,17 @@ def test_nonhydro_corrector_step(
nflatlev=grid_savepoint.nflatlev(),
nflat_gradp=grid_savepoint.nflat_gradp(),
)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
dtime = sp.get_metadata("dtime").get("dtime")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0 # TODO: @abishekg7 read from serialized data?
nnew = 1

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

z_fields = IntermediateFields(
z_gradh_exner=sp.z_gradh_exner(),
Expand All @@ -540,8 +537,8 @@ def test_nonhydro_corrector_step(
z_graddiv_vn=sp.z_graddiv_vn(),
z_rho_expl=sp.z_rho_expl(),
z_dwdz_dd=sp.z_dwdz_dd(),
z_kin_hor_e=sp_v.z_kin_hor_e(),
z_vt_ie=sp_v.z_vt_ie(),
z_kin_hor_e=sp.z_kin_hor_e(),
z_vt_ie=sp.z_vt_ie(),
)

divdamp_fac_o2 = sp.divdamp_fac_o2()
Expand Down Expand Up @@ -683,7 +680,6 @@ def test_run_solve_nonhydro_single_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init, # TODO (magdalena) this should not be needed in test_solve_nonhydro.py, only for test_velocity_advection.py
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -697,20 +693,19 @@ def test_run_solve_nonhydro_single_step(
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
dtime = sp.get_metadata("dtime").get("dtime")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
Expand Down Expand Up @@ -797,7 +792,7 @@ def test_run_solve_nonhydro_multi_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
vn_only,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -810,20 +805,19 @@ def test_run_solve_nonhydro_multi_step(
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
dtime = sp.get_metadata("dtime").get("dtime")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
diagnostic_state_nh = construct_diagnostics(sp, sp_v)
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

diagnostic_state_nh = construct_diagnostics(sp)
prognostic_state_ls = create_prognostic_states(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
Expand Down
31 changes: 29 additions & 2 deletions model/common/src/icon4py/model/common/math/smagorinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next import Field, field_operator
from gt4py.next import Field, field_operator, program
from gt4py.next.ffront.fbuiltins import broadcast, maximum, minimum

from icon4py.model.common.dimension import KDim, Koff


@field_operator
def en_smag_fac_for_zero_nshift(
def _en_smag_fac_for_zero_nshift(
vect_a: Field[[KDim], float],
hdiff_smag_fac: float,
hdiff_smag_fac2: float,
Expand All @@ -45,3 +45,30 @@ def en_smag_fac_for_zero_nshift(
dzqdr = minimum(broadcast(dz42, (KDim,)), maximum(zero, zf - hdiff_smag_z2))
enh_smag_fac = hdiff_smag_fac + (dzlin * alin) + dzqdr * (aqdr + dzqdr * bqdr)
return enh_smag_fac


@program
def en_smag_fac_for_zero_nshift(
vect_a: Field[[KDim], float],
hdiff_smag_fac: float,
hdiff_smag_fac2: float,
hdiff_smag_fac3: float,
hdiff_smag_fac4: float,
hdiff_smag_z: float,
hdiff_smag_z2: float,
hdiff_smag_z3: float,
hdiff_smag_z4: float,
enh_smag_fac: Field[[KDim], float],
):
_en_smag_fac_for_zero_nshift(
vect_a,
hdiff_smag_fac,
hdiff_smag_fac2,
hdiff_smag_fac3,
hdiff_smag_fac4,
hdiff_smag_z,
hdiff_smag_z2,
hdiff_smag_z3,
hdiff_smag_z4,
out=enh_smag_fac,
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def icon_grid(grid_savepoint, experiment):
Uses the special grid_savepoint that contains data from p_patch
"""
return grid_savepoint.construct_icon_grid()
return grid_savepoint.construct_icon_grid(on_gpu=False)


@pytest.fixture
Expand Down
35 changes: 23 additions & 12 deletions model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,18 @@ def construct_prognostics(self) -> PrognosticState:


class IconNonHydroInitSavepoint(IconSavepoint):
def z_vt_ie(self):
return self._get_field("z_vt_ie", EdgeDim, KDim)

def z_kin_hor_e(self):
return self._get_field("z_kin_hor_e", EdgeDim, KDim)

def vn_ie(self):
return self._get_field("vn_ie", EdgeDim, KDim)

def vt(self):
return self._get_field("vt", EdgeDim, KDim)

def bdy_divdamp(self):
return self._get_field("bdy_divdamp", KDim)

Expand Down Expand Up @@ -732,21 +744,20 @@ def grf_tend_thv(self):
def grf_tend_vn(self):
return self._get_field("grf_tend_vn", EdgeDim, KDim)

def w_concorr_c(self):
return self._get_field("w_concorr_c", CellDim, KDim)

def ddt_vn_apc_pc(self, ntnd):
return self._get_field_component("ddt_vn_apc_pc", ntnd, (EdgeDim, KDim))

def ddt_w_adv_pc(self, ntnd):
return self._get_field_component("ddt_w_adv_ntl", ntnd, (CellDim, KDim))

def ddt_vn_adv_ntl(self, ntl):
buffer = np.squeeze(self.serializer.read("ddt_vn_adv_ntl", self.savepoint).astype(float))[
:, :, ntl - 1
]
dims = (EdgeDim, KDim)
buffer = self._reduce_to_dim_size(buffer, dims)
return as_field(dims, buffer)
return self._get_field_component("ddt_vn_apc_pc", ntl, (EdgeDim, KDim))

def ddt_w_adv_ntl(self, ntl):
buffer = np.squeeze(self.serializer.read("ddt_w_adv_ntl", self.savepoint).astype(float))[
:, :, ntl - 1
]
dims = (CellDim, KDim)
buffer = self._reduce_to_dim_size(buffer, dims)
return as_field(dims, buffer)
return self._get_field_component("ddt_w_adv_ntl", ntl, (CellDim, KDim))

def grf_tend_w(self):
return self._get_field("grf_tend_w", CellDim, KDim)
Expand Down
7 changes: 5 additions & 2 deletions model/common/tests/math_tests/test_smagorinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np
from gt4py.next.program_processors.runners.roundtrip import backend as roundtrip

from icon4py.model.common.dimension import KDim
from icon4py.model.common.grid.simple import SimpleGrid
Expand All @@ -20,6 +21,8 @@
from icon4py.model.common.test_utils.reference_funcs import enhanced_smagorinski_factor_numpy


# TODO (magdalena) stencil does not run on embedded backend, broadcast(0.0, (KDim,)) return scalar?
# TODO (magdalena) run as to StencilTest
def test_init_enh_smag_fac():
grid = SimpleGrid()
enh_smag_fac = zero_field(grid, KDim)
Expand All @@ -28,11 +31,11 @@ def test_init_enh_smag_fac():
z = (0.1, 0.2, 0.3, 0.4)

enhanced_smag_fac_np = enhanced_smagorinski_factor_numpy(fac, z, a_vec.asnumpy())
en_smag_fac_for_zero_nshift(
en_smag_fac_for_zero_nshift.with_backend(roundtrip)(
a_vec,
*fac,
*z,
out=enh_smag_fac,
enh_smag_fac,
offset_provider={"Koff": KDim},
)
assert np.allclose(enhanced_smag_fac_np, enh_smag_fac.asnumpy())
Loading

0 comments on commit a5c8e18

Please sign in to comment.