From 342968f67944af54ff6fb808f4d4e0d97dedffa8 Mon Sep 17 00:00:00 2001
From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com>
Date: Fri, 3 May 2024 11:44:38 +0200
Subject: [PATCH] Metrics fields second batch (#454)

* further interpolation coefficients
---
 .../mo_intp_rbf_rbf_vec_interpol_cell.py      |  61 +++++++++
 .../mo_intp_rbf_rbf_vec_interpol_vertex.py    |   8 +-
 .../model/common/metrics/metric_fields.py     |  60 ++++++--
 .../common/test_utils/serialbox_utils.py      |   3 +
 .../tests/metric_tests/test_metric_fields.py  | 128 +++++++++++++++---
 5 files changed, 228 insertions(+), 32 deletions(-)
 create mode 100644 model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_cell.py

diff --git a/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_cell.py b/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_cell.py
new file mode 100644
index 0000000000..ecb8dd1aa7
--- /dev/null
+++ b/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_cell.py
@@ -0,0 +1,61 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from gt4py.next.common import GridType
+from gt4py.next.ffront.decorator import field_operator, program
+from gt4py.next.ffront.fbuiltins import Field, int32, neighbor_sum
+
+from icon4py.model.common.dimension import (
+    C2E,
+    C2EDim,
+    CellDim,
+    EdgeDim,
+    KDim,
+)
+from icon4py.model.common.settings import backend
+from icon4py.model.common.type_alias import wpfloat
+
+
+@field_operator
+def _mo_intp_rbf_rbf_vec_interpol_cell(
+    p_vn_in: Field[[EdgeDim, KDim], wpfloat],
+    ptr_coeff_1: Field[[CellDim, C2EDim], wpfloat],
+    ptr_coeff_2: Field[[CellDim, C2EDim], wpfloat],
+) -> tuple[Field[[CellDim, KDim], wpfloat], Field[[CellDim, KDim], wpfloat]]:
+    p_u_out = neighbor_sum(ptr_coeff_1 * p_vn_in(C2E), axis=C2EDim)
+    p_v_out = neighbor_sum(ptr_coeff_2 * p_vn_in(C2E), axis=C2EDim)
+    return p_u_out, p_v_out
+
+
+@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
+def mo_intp_rbf_rbf_vec_interpol_cell(
+    p_vn_in: Field[[EdgeDim, KDim], wpfloat],
+    ptr_coeff_1: Field[[CellDim, C2EDim], wpfloat],
+    ptr_coeff_2: Field[[CellDim, C2EDim], wpfloat],
+    p_u_out: Field[[CellDim, KDim], wpfloat],
+    p_v_out: Field[[CellDim, KDim], wpfloat],
+    horizontal_start: int32,
+    horizontal_end: int32,
+    vertical_start: int32,
+    vertical_end: int32,
+):
+    _mo_intp_rbf_rbf_vec_interpol_cell(
+        p_vn_in,
+        ptr_coeff_1,
+        ptr_coeff_2,
+        out=(p_u_out, p_v_out),
+        domain={
+            CellDim: (horizontal_start, horizontal_end),
+            KDim: (vertical_start, vertical_end),
+        },
+    )
diff --git a/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_vertex.py b/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_vertex.py
index 2d867d0a93..6c74313020 100644
--- a/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_vertex.py
+++ b/model/common/src/icon4py/model/common/interpolation/stencils/mo_intp_rbf_rbf_vec_interpol_vertex.py
@@ -15,7 +15,13 @@
 from gt4py.next.ffront.decorator import field_operator, program
 from gt4py.next.ffront.fbuiltins import Field, int32, neighbor_sum
 
-from icon4py.model.common.dimension import V2E, EdgeDim, KDim, V2EDim, VertexDim
+from icon4py.model.common.dimension import (
+    V2E,
+    EdgeDim,
+    KDim,
+    V2EDim,
+    VertexDim,
+)
 from icon4py.model.common.settings import backend
 from icon4py.model.common.type_alias import wpfloat
 
diff --git a/model/common/src/icon4py/model/common/metrics/metric_fields.py b/model/common/src/icon4py/model/common/metrics/metric_fields.py
index 7869306e6b..726510685d 100644
--- a/model/common/src/icon4py/model/common/metrics/metric_fields.py
+++ b/model/common/src/icon4py/model/common/metrics/metric_fields.py
@@ -25,12 +25,17 @@
     where,
 )
 
-from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, Koff, VertexDim
+from icon4py.model.common.dimension import (
+    CellDim,
+    EdgeDim,
+    KDim,
+    Koff,
+    VertexDim,
+)
 from icon4py.model.common.math.helpers import (
     _grad_fd_tang,
     average_cell_kdim_level_up,
     average_edge_kdim_level_up,
-    difference_k_level_down,
     difference_k_level_up,
     grad_fd_norm,
 )
@@ -79,12 +84,11 @@ def _compute_ddqz_z_half(
     z_mc: Field[[CellDim, KDim], wpfloat],
     k: Field[[KDim], int32],
     nlev: int32,
-) -> Field[[CellDim, KDim], wpfloat]:
-    ddqz_z_half = where(
-        (k > int32(0)) & (k < nlev),
-        difference_k_level_down(z_mc),
-        where(k == 0, 2.0 * (z_ifc - z_mc), 2.0 * (z_mc(Koff[-1]) - z_ifc)),
-    )
+):
+    # TODO: change this to concat_where once it's merged
+    ddqz_z_half = where(k == 0, 2.0 * (z_ifc - z_mc), 0.0)
+    ddqz_z_half = where((k > 0) & (k < nlev), z_mc(Koff[-1]) - z_mc, ddqz_z_half)
+    ddqz_z_half = where(k == nlev, 2.0 * (z_mc(Koff[-1]) - z_ifc), ddqz_z_half)
     return ddqz_z_half
 
 
@@ -127,7 +131,7 @@ def compute_ddqz_z_half(
 
 
 @field_operator
-def _compute_ddqz_z_full(
+def _compute_ddqz_z_full_and_inverse(
     z_ifc: Field[[CellDim, KDim], wpfloat],
 ) -> tuple[Field[[CellDim, KDim], wpfloat], Field[[CellDim, KDim], wpfloat]]:
     ddqz_z_full = difference_k_level_up(z_ifc)
@@ -136,7 +140,7 @@ def _compute_ddqz_z_full(
 
 
 @program(grid_type=GridType.UNSTRUCTURED)
-def compute_ddqz_z_full(
+def compute_ddqz_z_full_and_inverse(
     z_ifc: Field[[CellDim, KDim], wpfloat],
     ddqz_z_full: Field[[CellDim, KDim], wpfloat],
     inv_ddqz_z_full: Field[[CellDim, KDim], wpfloat],
@@ -161,7 +165,7 @@ def compute_ddqz_z_full(
         vertical_end: vertical end index
 
     """
-    _compute_ddqz_z_full(
+    _compute_ddqz_z_full_and_inverse(
         z_ifc,
         out=(ddqz_z_full, inv_ddqz_z_full),
         domain={CellDim: (horizontal_start, horizontal_end), KDim: (vertical_start, vertical_end)},
@@ -488,7 +492,39 @@ def compute_ddxt_z_half_e(
 
 
 @program
-def compute_ddxnt_z_full(
+def compute_ddxn_z_full(
     z_ddxnt_z_half_e: Field[[EdgeDim, KDim], float], ddxn_z_full: Field[[EdgeDim, KDim], float]
 ):
     average_edge_kdim_level_up(z_ddxnt_z_half_e, out=ddxn_z_full)
+
+
+@field_operator
+def _compute_vwind_expl_wgt(vwind_impl_wgt: Field[[CellDim], wpfloat]) -> Field[[CellDim], wpfloat]:
+    return 1.0 - vwind_impl_wgt
+
+
+@program(grid_type=GridType.UNSTRUCTURED)
+def compute_vwind_expl_wgt(
+    vwind_impl_wgt: Field[[CellDim], wpfloat],
+    vwind_expl_wgt: Field[[CellDim], wpfloat],
+    horizontal_start: int32,
+    horizontal_end: int32,
+):
+    """
+    Compute vwind_expl_wgt.
+
+    See mo_vertical_grid.f90
+
+    Args:
+        vwind_impl_wgt: offcentering in vertical mass flux
+        vwind_expl_wgt: (output) 1 - of vwind_impl_wgt
+        horizontal_start: horizontal start index
+        horizontal_end: horizontal end index
+
+    """
+
+    _compute_vwind_expl_wgt(
+        vwind_impl_wgt=vwind_impl_wgt,
+        out=vwind_expl_wgt,
+        domain={CellDim: (horizontal_start, horizontal_end)},
+    )
diff --git a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
index b45289c556..3f69c11049 100644
--- a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
+++ b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
@@ -200,6 +200,9 @@ def edge_areas(self):
     def inv_dual_edge_length(self):
         return self._get_field("inv_dual_edge_length", EdgeDim)
 
+    def dual_edge_length(self):
+        return self._get_field("dual_edge_length", EdgeDim)
+
     def edge_cell_length(self):
         return self._get_field("edge_cell_length", EdgeDim, E2CDim)
 
diff --git a/model/common/tests/metric_tests/test_metric_fields.py b/model/common/tests/metric_tests/test_metric_fields.py
index 370e5383d0..739fdbc7a6 100644
--- a/model/common/tests/metric_tests/test_metric_fields.py
+++ b/model/common/tests/metric_tests/test_metric_fields.py
@@ -19,7 +19,13 @@
 from gt4py.next.ffront.fbuiltins import int32
 
 from icon4py.model.common import constants
-from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, V2CDim, VertexDim
+from icon4py.model.common.dimension import (
+    CellDim,
+    EdgeDim,
+    KDim,
+    V2CDim,
+    VertexDim,
+)
 from icon4py.model.common.grid.horizontal import (
     HorizontalMarkerIndex,
     _compute_cells2verts,
@@ -28,13 +34,14 @@
 from icon4py.model.common.metrics.metric_fields import (
     compute_coeff_dwdz,
     compute_d2dexdz2_fac_mc,
-    compute_ddqz_z_full,
+    compute_ddqz_z_full_and_inverse,
     compute_ddqz_z_half,
+    compute_ddxn_z_full,
     compute_ddxn_z_half_e,
-    compute_ddxnt_z_full,
     compute_ddxt_z_half_e,
     compute_rayleigh_w,
     compute_scalfac_dd3d,
+    compute_vwind_expl_wgt,
     compute_z_mc,
 )
 from icon4py.model.common.test_utils.datatest_utils import (
@@ -90,7 +97,7 @@ def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):
         pytest.skip("skipping: unsupported backend")
     ddq_z_half_ref = metrics_savepoint.ddqz_z_half()
     z_ifc = metrics_savepoint.z_ifc()
-    z_mc = zero_field(icon_grid, CellDim, KDim)
+    z_mc = zero_field(icon_grid, CellDim, KDim, extend={KDim: 1})
     nlevp1 = icon_grid.num_levels + 1
     k_index = as_field((KDim,), np.arange(nlevp1, dtype=int32))
     compute_z_mc.with_backend(backend)(
@@ -102,14 +109,14 @@ def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):
         vertical_end=int32(icon_grid.num_levels),
         offset_provider={"Koff": icon_grid.get_offset_provider("Koff")},
     )
-    ddq_z_half = zero_field(icon_grid, CellDim, KDim, extend={KDim: 1})
+    ddqz_z_half = zero_field(icon_grid, CellDim, KDim, extend={KDim: 1})
 
     compute_ddqz_z_half.with_backend(backend=backend)(
         z_ifc=z_ifc,
         z_mc=z_mc,
         k=k_index,
         nlev=icon_grid.num_levels,
-        ddqz_z_half=ddq_z_half,
+        ddqz_z_half=ddqz_z_half,
         horizontal_start=0,
         horizontal_end=icon_grid.num_cells,
         vertical_start=0,
@@ -117,11 +124,11 @@ def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):
         offset_provider={"Koff": icon_grid.get_offset_provider("Koff")},
     )
 
-    assert dallclose(ddq_z_half.asnumpy(), ddq_z_half_ref.asnumpy())
+    assert dallclose(ddqz_z_half.asnumpy(), ddq_z_half_ref.asnumpy())
 
 
 @pytest.mark.datatest
-def test_compute_ddqz_z_full(icon_grid, metrics_savepoint, backend):
+def test_compute_ddqz_z_full_and_inverse(icon_grid, metrics_savepoint, backend):
     if is_roundtrip(backend):
         pytest.skip("skipping: slow backend")
     z_ifc = metrics_savepoint.z_ifc()
@@ -129,7 +136,7 @@ def test_compute_ddqz_z_full(icon_grid, metrics_savepoint, backend):
     ddqz_z_full = zero_field(icon_grid, CellDim, KDim)
     inv_ddqz_z_full = zero_field(icon_grid, CellDim, KDim)
 
-    compute_ddqz_z_full.with_backend(backend)(
+    compute_ddqz_z_full_and_inverse.with_backend(backend)(
         z_ifc=z_ifc,
         ddqz_z_full=ddqz_z_full,
         inv_ddqz_z_full=inv_ddqz_z_full,
@@ -269,18 +276,97 @@ def test_compute_d2dexdz2_fac_mc(icon_grid, metrics_savepoint, grid_savepoint, b
     assert dallclose(d2dexdz2_fac2_mc_full.asnumpy(), d2dexdz2_fac2_mc_ref.asnumpy())
 
 
+@pytest.mark.datatest
+def test_compute_ddxt_z_full_e(
+    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint
+):
+    z_ifc = metrics_savepoint.z_ifc()
+    tangent_orientation = grid_savepoint.tangent_orientation()
+    inv_primal_edge_length = grid_savepoint.inverse_primal_edge_lengths()
+    ddxt_z_full_ref = metrics_savepoint.ddxt_z_full().asnumpy()
+    horizontal_start_vertex = icon_grid.get_start_index(
+        VertexDim,
+        HorizontalMarkerIndex.lateral_boundary(VertexDim) + 1,
+    )
+    horizontal_end_vertex = icon_grid.get_end_index(
+        VertexDim,
+        HorizontalMarkerIndex.lateral_boundary(VertexDim) - 1,
+    )
+    horizontal_start_edge = icon_grid.get_start_index(
+        EdgeDim,
+        HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 2,
+    )
+    horizontal_end_edge = icon_grid.get_end_index(
+        EdgeDim,
+        HorizontalMarkerIndex.lateral_boundary(EdgeDim) - 1,
+    )
+    vertical_start = 0
+    vertical_end = icon_grid.num_levels + 1
+    cells_aw_verts = interpolation_savepoint.c_intp().asnumpy()
+    z_ifv = zero_field(icon_grid, VertexDim, KDim, extend={KDim: 1})
+    _compute_cells2verts(
+        z_ifc,
+        as_field((VertexDim, V2CDim), cells_aw_verts),
+        out=z_ifv,
+        offset_provider={"V2C": icon_grid.get_offset_provider("V2C")},
+        domain={
+            VertexDim: (horizontal_start_vertex, horizontal_end_vertex),
+            KDim: (vertical_start, vertical_end),
+        },
+    )
+    ddxt_z_half_e = zero_field(icon_grid, EdgeDim, KDim, extend={KDim: 1})
+    compute_ddxt_z_half_e(
+        z_ifv=z_ifv,
+        inv_primal_edge_length=inv_primal_edge_length,
+        tangent_orientation=tangent_orientation,
+        ddxt_z_half_e=ddxt_z_half_e,
+        horizontal_start=horizontal_start_edge,
+        horizontal_end=horizontal_end_edge,
+        vertical_start=vertical_start,
+        vertical_end=vertical_end,
+        offset_provider={"E2V": icon_grid.get_offset_provider("E2V")},
+    )
+    ddxt_z_full = zero_field(icon_grid, EdgeDim, KDim)
+    compute_ddxn_z_full(
+        z_ddxnt_z_half_e=ddxt_z_half_e,
+        ddxn_z_full=ddxt_z_full,
+        offset_provider={"Koff": icon_grid.get_offset_provider("Koff")},
+    )
+
+    assert np.allclose(ddxt_z_full.asnumpy(), ddxt_z_full_ref)
+
+
+@pytest.mark.datatest
+def test_compute_vwind_expl_wgt(icon_grid, metrics_savepoint, backend):
+    vwind_expl_wgt_full = zero_field(icon_grid, CellDim)
+    vwind_expl_wgt_ref = metrics_savepoint.vwind_expl_wgt()
+    vwind_impl_wgt = metrics_savepoint.vwind_impl_wgt()
+
+    compute_vwind_expl_wgt.with_backend(backend)(
+        vwind_impl_wgt=vwind_impl_wgt,
+        vwind_expl_wgt=vwind_expl_wgt_full,
+        horizontal_start=int32(0),
+        horizontal_end=icon_grid.num_cells,
+        offset_provider={"C2E": icon_grid.get_offset_provider("C2E")},
+    )
+
+    assert dallclose(vwind_expl_wgt_full.asnumpy(), vwind_expl_wgt_ref.asnumpy())
+
+
 @pytest.mark.datatest
 @pytest.mark.parametrize("experiment", (REGIONAL_EXPERIMENT, GLOBAL_EXPERIMENT))
 def test_compute_ddqz_z_full_e(
-    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint
+    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint, backend
 ):
+    if is_roundtrip(backend):
+        pytest.skip("skipping: slow backend")
     ddqz_z_full = as_field((CellDim, KDim), 1.0 / metrics_savepoint.inv_ddqz_z_full().asnumpy())
     c_lin_e = interpolation_savepoint.c_lin_e()
     ddqz_z_full_e_ref = metrics_savepoint.ddqz_z_full_e().asnumpy()
     vertical_start = 0
     vertical_end = icon_grid.num_levels
     ddqz_z_full_e = zero_field(icon_grid, EdgeDim, KDim)
-    compute_cells2edges(
+    compute_cells2edges.with_backend(backend)(
         p_cell_in=ddqz_z_full,
         c_int=c_lin_e,
         p_vert_out=ddqz_z_full_e,
@@ -294,9 +380,11 @@ def test_compute_ddqz_z_full_e(
 
 
 @pytest.mark.datatest
-def test_compute_ddxn_z_full_e(
-    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint
+def test_compute_ddxn_z_full(
+    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint, backend
 ):
+    if is_roundtrip(backend):
+        pytest.skip("skipping: slow backend")
     z_ifc = metrics_savepoint.z_ifc()
     inv_dual_edge_length = grid_savepoint.inv_dual_edge_length()
     ddxn_z_full_ref = metrics_savepoint.ddxn_z_full().asnumpy()
@@ -311,7 +399,7 @@ def test_compute_ddxn_z_full_e(
     vertical_start = 0
     vertical_end = icon_grid.num_levels + 1
     ddxn_z_half_e = zero_field(icon_grid, EdgeDim, KDim, extend={KDim: 1})
-    compute_ddxn_z_half_e(
+    compute_ddxn_z_half_e.with_backend(backend)(
         z_ifc=z_ifc,
         inv_dual_edge_length=inv_dual_edge_length,
         ddxn_z_half_e=ddxn_z_half_e,
@@ -322,7 +410,7 @@ def test_compute_ddxn_z_full_e(
         offset_provider={"E2C": icon_grid.get_offset_provider("E2C")},
     )
     ddxn_z_full = zero_field(icon_grid, EdgeDim, KDim)
-    compute_ddxnt_z_full(
+    compute_ddxn_z_full.with_backend(backend)(
         z_ddxnt_z_half_e=ddxn_z_half_e,
         ddxn_z_full=ddxn_z_full,
         offset_provider={"Koff": icon_grid.get_offset_provider("Koff")},
@@ -332,9 +420,11 @@ def test_compute_ddxn_z_full_e(
 
 
 @pytest.mark.datatest
-def test_compute_ddxt_z_full_e(
-    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint
+def test_compute_ddxt_z_full(
+    grid_savepoint, interpolation_savepoint, icon_grid, metrics_savepoint, backend
 ):
+    if is_roundtrip(backend):
+        pytest.skip("skipping: slow backend")
     z_ifc = metrics_savepoint.z_ifc()
     tangent_orientation = grid_savepoint.tangent_orientation()
     inv_primal_edge_length = grid_savepoint.inverse_primal_edge_lengths()
@@ -370,7 +460,7 @@ def test_compute_ddxt_z_full_e(
         },
     )
     ddxt_z_half_e = zero_field(icon_grid, EdgeDim, KDim, extend={KDim: 1})
-    compute_ddxt_z_half_e(
+    compute_ddxt_z_half_e.with_backend(backend)(
         z_ifv=z_ifv,
         inv_primal_edge_length=inv_primal_edge_length,
         tangent_orientation=tangent_orientation,
@@ -382,7 +472,7 @@ def test_compute_ddxt_z_full_e(
         offset_provider={"E2V": icon_grid.get_offset_provider("E2V")},
     )
     ddxt_z_full = zero_field(icon_grid, EdgeDim, KDim)
-    compute_ddxnt_z_full(
+    compute_ddxn_z_full.with_backend(backend)(
         z_ddxnt_z_half_e=ddxt_z_half_e,
         ddxn_z_full=ddxt_z_full,
         offset_provider={"Koff": icon_grid.get_offset_provider("Koff")},