Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test dependencies on savepoints #349

Merged
merged 12 commits into from
Feb 6, 2024
Merged
97 changes: 45 additions & 52 deletions model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.datatest_utils import GLOBAL_EXPERIMENT, REGIONAL_EXPERIMENT
from icon4py.model.common.test_utils.helpers import dallclose
from icon4py.model.common.test_utils.serialbox_utils import IconNonHydroInitSavepoint

from .utils import (
construct_config,
Expand Down Expand Up @@ -102,14 +103,12 @@ def test_validate_divdamp_fields_against_savepoint_values(
def test_nonhydro_predictor_step(
istep_init,
istep_exit,
jstep_init,
step_date_init,
step_date_exit,
icon_grid,
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -123,16 +122,15 @@ def test_nonhydro_predictor_step(
sp_exit = savepoint_nonhydro_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
recompute = sp_v.get_metadata("recompute").get("recompute")
dtime = sp.get_metadata("dtime").get("dtime")
recompute = sp.get_metadata("recompute").get("recompute")
dyn_timestep = sp.get_metadata("dyn_timestep").get("dyn_timestep")
linit = sp_v.get_metadata("linit").get("linit")
linit = sp.get_metadata("linit").get("linit")

nnow = 0
nnew = 1

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
Expand Down Expand Up @@ -443,29 +441,29 @@ def test_nonhydro_predictor_step(
assert dallclose(prognostic_state_nnew.theta_v.asnumpy(), sp_exit.theta_v_new().asnumpy())


def construct_diagnostics(sp, sp_v):
def construct_diagnostics(init_savepoint: IconNonHydroInitSavepoint):
return DiagnosticStateNonHydro(
theta_v_ic=sp.theta_v_ic(),
exner_pr=sp.exner_pr(),
rho_ic=sp.rho_ic(),
ddt_exner_phy=sp.ddt_exner_phy(),
grf_tend_rho=sp.grf_tend_rho(),
grf_tend_thv=sp.grf_tend_thv(),
grf_tend_w=sp.grf_tend_w(),
mass_fl_e=sp.mass_fl_e(),
ddt_vn_phy=sp.ddt_vn_phy(),
grf_tend_vn=sp.grf_tend_vn(),
ddt_vn_apc_ntl1=sp_v.ddt_vn_apc_pc(1),
ddt_vn_apc_ntl2=sp_v.ddt_vn_apc_pc(2),
ddt_w_adv_ntl1=sp_v.ddt_w_adv_pc(1),
ddt_w_adv_ntl2=sp_v.ddt_w_adv_pc(2),
vt=sp_v.vt(),
vn_ie=sp_v.vn_ie(),
w_concorr_c=sp_v.w_concorr_c(),
theta_v_ic=init_savepoint.theta_v_ic(),
exner_pr=init_savepoint.exner_pr(),
rho_ic=init_savepoint.rho_ic(),
ddt_exner_phy=init_savepoint.ddt_exner_phy(),
grf_tend_rho=init_savepoint.grf_tend_rho(),
grf_tend_thv=init_savepoint.grf_tend_thv(),
grf_tend_w=init_savepoint.grf_tend_w(),
mass_fl_e=init_savepoint.mass_fl_e(),
ddt_vn_phy=init_savepoint.ddt_vn_phy(),
grf_tend_vn=init_savepoint.grf_tend_vn(),
ddt_vn_apc_ntl1=init_savepoint.ddt_vn_apc_pc(1),
ddt_vn_apc_ntl2=init_savepoint.ddt_vn_apc_pc(2),
ddt_w_adv_ntl1=init_savepoint.ddt_w_adv_pc(1),
ddt_w_adv_ntl2=init_savepoint.ddt_w_adv_pc(2),
vt=init_savepoint.vt(),
vn_ie=init_savepoint.vn_ie(),
w_concorr_c=init_savepoint.w_concorr_c(),
rho_incr=None, # sp.rho_incr(),
vn_incr=None, # sp.vn_incr(),
exner_incr=None, # sp.exner_incr(),
exner_dyn_incr=sp.exner_dyn_incr(),
exner_dyn_incr=init_savepoint.exner_dyn_incr(),
)


Expand Down Expand Up @@ -496,7 +494,6 @@ def test_nonhydro_corrector_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -514,18 +511,17 @@ def test_nonhydro_corrector_step(
nflatlev=grid_savepoint.nflatlev(),
nflat_gradp=grid_savepoint.nflat_gradp(),
)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
dtime = sp.get_metadata("dtime").get("dtime")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0 # TODO: @abishekg7 read from serialized data?
nnew = 1

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

z_fields = IntermediateFields(
z_gradh_exner=sp.z_gradh_exner(),
Expand All @@ -540,8 +536,8 @@ def test_nonhydro_corrector_step(
z_graddiv_vn=sp.z_graddiv_vn(),
z_rho_expl=sp.z_rho_expl(),
z_dwdz_dd=sp.z_dwdz_dd(),
z_kin_hor_e=sp_v.z_kin_hor_e(),
z_vt_ie=sp_v.z_vt_ie(),
z_kin_hor_e=sp.z_kin_hor_e(),
z_vt_ie=sp.z_vt_ie(),
)

divdamp_fac_o2 = sp.divdamp_fac_o2()
Expand Down Expand Up @@ -676,7 +672,6 @@ def test_run_solve_nonhydro_single_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init, # TODO (magdalena) this should not be needed in test_solve_nonhydro.py, only for test_velocity_advection.py
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -690,21 +685,20 @@ def test_run_solve_nonhydro_single_step(
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
dtime = sp.get_metadata("dtime").get("dtime")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
dyn_timestep = sp_v.get_metadata("dyn_timestep").get("dyn_timestep")
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")
dyn_timestep = sp.get_metadata("dyn_timestep").get("dyn_timestep")

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
Expand Down Expand Up @@ -784,7 +778,7 @@ def test_run_solve_nonhydro_multi_step(
savepoint_nonhydro_init,
damping_height,
grid_savepoint,
savepoint_velocity_init,
vn_only,
metrics_savepoint,
interpolation_savepoint,
savepoint_nonhydro_exit,
Expand All @@ -797,21 +791,20 @@ def test_run_solve_nonhydro_multi_step(
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
vertical_params = create_vertical_params(damping_height, grid_savepoint)
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
lprep_adv = sp_v.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp_v.get_metadata("clean_mflx").get("clean_mflx")
dtime = sp.get_metadata("dtime").get("dtime")
lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
prep_adv = PrepAdvection(
vn_traj=sp.vn_traj(), mass_flx_me=sp.mass_flx_me(), mass_flx_ic=sp.mass_flx_ic()
)

nnow = 0
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
dyn_timestep = sp_v.get_metadata("dyn_timestep").get("dyn_timestep")
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")
dyn_timestep = sp.get_metadata("dyn_timestep").get("dyn_timestep")

diagnostic_state_nh = construct_diagnostics(sp, sp_v)
diagnostic_state_nh = construct_diagnostics(sp)

prognostic_state_ls = create_prognostic_states(sp)

Expand Down
2 changes: 2 additions & 0 deletions model/common/src/icon4py/model/common/grid/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
C2E2CDim,
C2E2CODim,
C2EDim,
C2VDim,
CECDim,
CEDim,
CellDim,
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(self):
"E2C2V": (self._get_offset_provider, E2C2VDim, EdgeDim, VertexDim),
"V2E": (self._get_offset_provider, V2EDim, VertexDim, EdgeDim),
"V2C": (self._get_offset_provider, V2CDim, VertexDim, CellDim),
"C2V": (self._get_offset_provider, C2VDim, CellDim, VertexDim),
"E2ECV": (self._get_offset_provider_for_sparse_fields, E2C2VDim, EdgeDim, ECVDim),
"C2CEC": (self._get_offset_provider_for_sparse_fields, C2E2CDim, CellDim, CECDim),
"C2CE": (self._get_offset_provider_for_sparse_fields, C2EDim, CellDim, CEDim),
Expand Down
35 changes: 23 additions & 12 deletions model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,18 @@ def construct_prognostics(self) -> PrognosticState:


class IconNonHydroInitSavepoint(IconSavepoint):
def z_vt_ie(self):
return self._get_field("z_vt_ie", EdgeDim, KDim)

def z_kin_hor_e(self):
return self._get_field("z_kin_hor_e", EdgeDim, KDim)

def vn_ie(self):
return self._get_field("vn_ie", EdgeDim, KDim)

def vt(self):
return self._get_field("vt", EdgeDim, KDim)

def bdy_divdamp(self):
return self._get_field("bdy_divdamp", KDim)

Expand Down Expand Up @@ -729,21 +741,20 @@ def grf_tend_thv(self):
def grf_tend_vn(self):
return self._get_field("grf_tend_vn", EdgeDim, KDim)

def w_concorr_c(self):
return self._get_field("w_concorr_c", CellDim, KDim)

def ddt_vn_apc_pc(self, ntnd):
return self._get_field_component("ddt_vn_apc_pc", ntnd, (EdgeDim, KDim))

def ddt_w_adv_pc(self, ntnd):
return self._get_field_component("ddt_w_adv_ntl", ntnd, (CellDim, KDim))

def ddt_vn_adv_ntl(self, ntl):
buffer = np.squeeze(self.serializer.read("ddt_vn_adv_ntl", self.savepoint).astype(float))[
:, :, ntl - 1
]
dims = (EdgeDim, KDim)
buffer = self._reduce_to_dim_size(buffer, dims)
return as_field(dims, buffer)
return self._get_field_component("ddt_vn_apc_pc", ntl, (EdgeDim, KDim))

def ddt_w_adv_ntl(self, ntl):
buffer = np.squeeze(self.serializer.read("ddt_w_adv_ntl", self.savepoint).astype(float))[
:, :, ntl - 1
]
dims = (CellDim, KDim)
buffer = self._reduce_to_dim_size(buffer, dims)
return as_field(dims, buffer)
return self._get_field_component("ddt_w_adv_ntl", ntl, (CellDim, KDim))

def grf_tend_w(self):
return self._get_field("grf_tend_w", CellDim, KDim)
Expand Down