Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test GTIR-DaCe backend #638

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f346ef1
Use new itir.Program everywhere
tehrengruber Nov 14, 2024
a80a1e0
Merge remote-tracking branch 'origin/main' into update_to_gtir
tehrengruber Dec 6, 2024
61e97e4
Use gt4py main again
tehrengruber Dec 6, 2024
416d7e7
fix connectivities
havogt Dec 9, 2024
38a162c
switch gt4py branch
havogt Dec 9, 2024
21c495d
fix more connectivities
edopao Dec 20, 2024
0e8fcba
fix more connectivities (1)
edopao Dec 20, 2024
d784012
Merge remote-tracking branch 'origin/main' into update_to_gtir
edopao Dec 20, 2024
7917c51
update versions (temporarily)
havogt Jan 8, 2025
e7e5a03
Merge remote-tracking branch 'origin/main' into update_to_gtir_dace
edopao Jan 10, 2025
64b12d3
switch gt4py branch to dace-gtir-scan
edopao Jan 10, 2025
150d48c
update uv lock
edopao Jan 10, 2025
c26c235
update dace CI-config
edopao Jan 10, 2025
b627c71
pytest marker for tests that require concat_where
edopao Jan 10, 2025
b80e5b7
update lock file
edopao Jan 10, 2025
042f7dd
disable orchestration tests
edopao Jan 10, 2025
78d81ce
disable orchestration tests (1)
edopao Jan 10, 2025
84f1ec7
Merge remote-tracking branch 'origin/main' into update_to_gtir_dace
edopao Jan 10, 2025
363cc2a
update uv lock
edopao Jan 13, 2025
6481b65
enable orchestration test cases in diffusion module
edopao Jan 13, 2025
e78cf66
fix parameterization of dace orchestration
edopao Jan 13, 2025
d15c2f6
update uv lock
edopao Jan 13, 2025
f0eca78
Fix some errors in orchestration decorator
edopao Jan 13, 2025
8307f04
update orchestrator decorator for gtir
edopao Jan 13, 2025
242850e
update uv lock
edopao Jan 13, 2025
3624cd2
DaCe Orchestration: WIP
kotsaloscv Jan 14, 2025
6c57fb2
workaround for compile_time_connectivities
edopao Jan 14, 2025
e78f239
add gt4py_cache to gitignore
edopao Jan 14, 2025
95eef9c
update uv lock
edopao Jan 14, 2025
195cba6
update dace check for diffusion orchestration
edopao Jan 15, 2025
eb50015
update uv lock
edopao Jan 15, 2025
ad1574d
update uv lock
edopao Jan 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ simple_mesh*.nc
**/docs/_source/*.rst

### GT4Py ####
.gt_cache/
.gt4py_cache/

# DaCe
.dacecache
Expand Down
15 changes: 9 additions & 6 deletions ci/dace.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ include:
.test_model_stencils:
stage: test
script:
- nox -s test_model_stencils-3.10 -- --backend=$BACKEND --grid=$GRID
- nox -s "test_model_stencils-3.10(subpackage='$COMPONENT')" -- --backend=$BACKEND --grid=$GRID
parallel:
matrix:
- BACKEND: [dace_cpu, dace_gpu]
GRID: [simple_grid, icon_grid]
# TODO(edopao): Add more components once they work fine with DaCe
COMPONENT: [atmosphere/diffusion, atmosphere/dycore]
test_model_stencils_x86_64:
extends: [.test_model_stencils, .test_template_x86_64]
test_model_stencils_aarch64:
Expand All @@ -20,9 +22,9 @@ test_model_stencils_aarch64:
- nox -s "test_model_datatest-3.10(subpackage='$COMPONENT')" -- --backend=$BACKEND
parallel:
matrix:
# TODO(edopao): Add more components once they work fine with DaCe
- COMPONENT: [atmosphere/diffusion, atmosphere/dycore]
BACKEND: [dace_cpu]
- BACKEND: [dace_cpu]
# TODO(edopao): Add more components once they work fine with DaCe
COMPONENT: [atmosphere/diffusion, atmosphere/dycore]
test_model_datatests_x86_64:
extends: [.test_model_datatests, .test_template_x86_64]
test_model_datatests_aarch64:
Expand All @@ -31,12 +33,13 @@ test_model_datatests_aarch64:
.benchmark_model_stencils:
stage: benchmark
script:
# force execution of tests where validation is expected to fail, because the reason for failure is wrong numpy reference
- nox -s benchmark_model-3.10 -- --backend=$BACKEND --grid=$GRID --runxfail
- nox -s "benchmark_model-3.10(subpackage='$COMPONENT')" -- --backend=$BACKEND --grid=$GRID
parallel:
matrix:
- BACKEND: [dace_cpu, dace_gpu]
GRID: [icon_grid, icon_grid_global]
# TODO(edopao): Add more components once they work fine with DaCe
COMPONENT: [atmosphere/diffusion, atmosphere/dycore]
benchmark_model_stencils_x86_64:
extends: [.benchmark_model_stencils, .test_template_x86_64]
benchmark_model_stencils_aarch64:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,9 @@ def __init__(

self._determine_horizontal_domains()

self.compile_time_connectivities = dace_orchestration.build_compile_time_connectivities(
self._grid.offset_providers
)
# TODO(edopao): we should call gtx.common.offset_provider_to_type()
# but this requires some changes in type inference.
self.compile_time_connectivities = self._grid.offset_providers

def _allocate_temporary_fields(self):
self.diff_multfac_vn = data_alloc.allocate_zero_field(
Expand Down Expand Up @@ -925,6 +925,7 @@ def orchestration_uid(self) -> str:
"_backend",
"_exchange",
"_grid",
"compile_time_connectivities",
*[
name
for name in self.__dict__.keys()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def reference(
**kwargs,
) -> dict:
c2e = grid.connectivities[dims.C2EDim]
c2ce = grid.get_offset_provider("C2CE").table
c2ce = grid.get_offset_provider("C2CE").ndarray

geofac_div = np.expand_dims(geofac_div, axis=-1)
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from icon4py.model.common.decomposition import definitions
from icon4py.model.common.grid import vertical as v_grid
from icon4py.model.common.utils import data_allocation as data_alloc
from icon4py.model.testing import datatest_utils, parallel_helpers
from icon4py.model.testing import datatest_utils, helpers, parallel_helpers

from .. import utils

Expand All @@ -22,7 +22,7 @@
@pytest.mark.parametrize("experiment", [datatest_utils.REGIONAL_EXPERIMENT])
@pytest.mark.parametrize("ndyn_substeps", [2])
@pytest.mark.parametrize("linit", [True, False])
@pytest.mark.parametrize("orchestration", [True, False])
@pytest.mark.parametrize("orchestration", [False, True])
def test_parallel_diffusion(
experiment,
step_date_init,
Expand All @@ -44,7 +44,7 @@ def test_parallel_diffusion(
backend,
orchestration,
):
if orchestration and ("dace" not in backend.name.lower()):
if orchestration and not helpers.is_dace(backend):
raise pytest.skip("This test is only executed for `dace` backends.")
caplog.set_level("INFO")
parallel_helpers.check_comm_size(processor_props)
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_parallel_diffusion_multiple_steps(
caplog,
backend,
):
if "dace" not in backend.name.lower():
if not helpers.is_dace(backend):
raise pytest.skip("This test is only executed for `dace` backends.")
######################################################################
# Diffusion initialization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ def test_verify_diffusion_init_against_savepoint(
(dt_utils.GLOBAL_EXPERIMENT, "2000-01-01T00:00:02.000", "2000-01-01T00:00:02.000"),
],
)
@pytest.mark.parametrize("ndyn_substeps, orchestration", [(2, [True, False])])
@pytest.mark.parametrize("ndyn_substeps", [2])
@pytest.mark.parametrize("orchestration", [False, True])
def test_run_diffusion_single_step(
savepoint_diffusion_init,
savepoint_diffusion_exit,
Expand All @@ -403,7 +404,7 @@ def test_run_diffusion_single_step(
backend,
orchestration,
):
if orchestration and ("dace" not in backend.name.lower()):
if orchestration and not helpers.is_dace(backend):
pytest.skip(f"running backend = '{backend.name}': orchestration only on dace backends")
grid = get_grid_for_experiment(experiment, backend)
cell_geometry = get_cell_geometry_for_experiment(experiment, backend)
Expand Down Expand Up @@ -502,8 +503,8 @@ def test_run_diffusion_multiple_steps(
backend,
icon_grid,
):
if "dace" not in backend.name.lower():
raise pytest.skip("This test is only executed for DaCe backends.")
if not helpers.is_dace(backend):
pytest.skip(f"running backend = '{backend.name}': orchestration only on dace backends")
######################################################################
# Diffusion initialization
######################################################################
Expand Down Expand Up @@ -624,7 +625,8 @@ def test_run_diffusion_multiple_steps(

@pytest.mark.datatest
@pytest.mark.parametrize("experiment", [dt_utils.REGIONAL_EXPERIMENT])
@pytest.mark.parametrize("linit, orchestration", [(True, [True, False])])
@pytest.mark.parametrize("linit", [True, True])
@pytest.mark.parametrize("orchestration", [False, True])
def test_run_diffusion_initial_step(
experiment,
linit,
Expand All @@ -639,7 +641,7 @@ def test_run_diffusion_initial_step(
backend,
orchestration,
):
if orchestration and ("dace" not in backend.name.lower()):
if orchestration and not helpers.is_dace(backend):
pytest.skip(f"running backend = '{backend.name}': orchestration only on dace backends")
grid = get_grid_for_experiment(experiment, backend)
cell_geometry = get_cell_geometry_for_experiment(experiment, backend)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_verify_geofac_n2s_field_manipulation(interpolation_savepoint, icon_grid
geofac_c = interpolation_state.geofac_n2s_c.asnumpy()
geofac_nbh = interpolation_state.geofac_n2s_nbh.asnumpy()
assert np.count_nonzero(geofac_nbh) > 0
cec_table = icon_grid.get_offset_provider("C2CEC").table
cec_table = icon_grid.get_offset_provider("C2CEC").ndarray
assert np.allclose(geofac_c, geofac_n2s[:, 0])
assert geofac_nbh[cec_table].shape == geofac_n2s[:, 1:].shape
assert np.allclose(geofac_nbh[cec_table], geofac_n2s[:, 1:])
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
from gt4py.next.ffront.decorator import GridType, field_operator, program
from gt4py.next.ffront.fbuiltins import where

from icon4py.model.atmosphere.dycore.stencils.compute_contravariant_correction_of_w import (
Expand All @@ -19,7 +18,7 @@
from icon4py.model.common.type_alias import vpfloat, wpfloat


@field_operator
@gtx.field_operator
def _fused_solve_nonhydro_stencil_39_40(
e_bln_c_s: gtx.Field[gtx.Dims[dims.CEDim], wpfloat],
z_w_concorr_me: fa.EdgeKField[vpfloat],
Expand All @@ -39,7 +38,7 @@ def _fused_solve_nonhydro_stencil_39_40(
return w_concorr_c


@program(grid_type=GridType.UNSTRUCTURED)
@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED)
def fused_solve_nonhydro_stencil_39_40(
e_bln_c_s: gtx.Field[gtx.Dims[dims.CEDim], wpfloat],
z_w_concorr_me: fa.EdgeKField[vpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def add_interpolated_horizontal_advection_of_w_numpy(
grid, e_bln_c_s: np.array, z_v_grad_w: np.array, ddt_w_adv: np.array, **kwargs
) -> np.array:
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
c2ce = grid.get_offset_provider("C2CE").table
c2ce = grid.get_offset_provider("C2CE").ndarray

ddt_w_adv = ddt_w_adv + np.sum(
z_v_grad_w[grid.connectivities[dims.C2EDim]] * e_bln_c_s[c2ce],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def reference(
) -> tuple[np.array]:
c2e = grid.connectivities[dims.C2EDim]
geofac_div = np.expand_dims(geofac_div, axis=-1)
c2ce = grid.get_offset_provider("C2CE").table
c2ce = grid.get_offset_provider("C2CE").ndarray

z_flxdiv_mass = np.sum(
geofac_div[c2ce] * mass_fl_e[c2e],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class TestFusedVelocityAdvectionStencil15To18(StencilTest):
"z_w_con_c_full",
"ddt_w_adv",
)
MARKER = (pytest.mark.requires_concat_where,)

@staticmethod
def _fused_velocity_advection_stencil_16_to_18(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def interpolate_to_cell_center_numpy(
grid, interpolant: np.array, e_bln_c_s: np.array, **kwargs
) -> np.array:
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
c2ce = grid.get_offset_provider("C2CE").table
c2ce = grid.get_offset_provider("C2CE").ndarray

interpolation = np.sum(
interpolant[grid.connectivities[dims.C2EDim]] * e_bln_c_s[c2ce],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
import numpy as np
from gt4py.next.program_processors.runners.gtfn import run_gtfn
from gt4py.next.program_processors.runners.dace import run_dace_cpu

from icon4py.model.atmosphere.dycore.stencils.solve_tridiagonal_matrix_for_w_forward_sweep import (
solve_tridiagonal_matrix_for_w_forward_sweep,
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_solve_tridiagonal_matrix_for_w_forward_sweep():
v_start = 1
v_end = gtx.int32(grid.num_levels)
# TODO we run this test with the DaCe CPU backend as the `embedded` backend doesn't handle this pattern
solve_tridiagonal_matrix_for_w_forward_sweep.with_backend(run_gtfn)(
solve_tridiagonal_matrix_for_w_forward_sweep.with_backend(run_dace_cpu)(
vwind_impl_wgt=vwind_impl_wgt,
theta_v_ic=theta_v_ic,
ddqz_z_half=ddqz_z_half,
Expand Down
2 changes: 1 addition & 1 deletion model/common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ version = "0.0.6"
all = ["icon4py-common[dace,distributed,io]"]
cuda11 = ['cupy-cuda11x>=13.0', 'gt4py[cuda11]']
cuda12 = ['cupy-cuda12x>=13.0', 'gt4py[cuda12]']
dace = ["dace<1.0", "gt4py[dace]"] # TODO(egparedes): DaCe max constraint should be transformed to min constraint after updating gt4py
dace = ["dace>=1.0", "gt4py[dace]"]
distributed = ["ghex>=0.3.0", "mpi4py>=3.1.5"]
io = [
# external dependencies
Expand Down
10 changes: 5 additions & 5 deletions model/common/src/icon4py/model/common/grid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ def _get_offset_provider(self, dim, from_dim, to_dim):
), 'Neighbor table\'s "{}" data type must be gtx.int32. Instead it\'s "{}"'.format(
dim, self.connectivities[dim].dtype
)
return gtx.NeighborTableOffsetProvider(
self.connectivities[dim],
from_dim,
return gtx.as_connectivity(
[from_dim, dim],
to_dim,
self.size[dim],
has_skip_values=self._has_skip_values(dim),
self.connectivities[dim],
skip_value=-1 if self._has_skip_values(dim) else None,
)

def _get_offset_provider_for_sparse_fields(self, dim, from_dim, to_dim):
if dim not in self.connectivities:
raise MissingConnectivity()
return grid_utils.neighbortable_offset_provider_for_1d_sparse_fields(
dim,
self.connectivities[dim].shape,
from_dim,
to_dim,
Expand Down
12 changes: 6 additions & 6 deletions model/common/src/icon4py/model/common/grid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
import numpy as np
from gt4py.next import Dimension, NeighborTableOffsetProvider
from gt4py.next import Dimension


def neighbortable_offset_provider_for_1d_sparse_fields(
dim: Dimension,
old_shape: tuple[int, int],
origin_axis: Dimension,
neighbor_axis: Dimension,
Expand All @@ -22,10 +23,9 @@ def neighbortable_offset_provider_for_1d_sparse_fields(
), 'Neighbor table\'s ("{}" to "{}") data type for 1d sparse fields must be gtx.int32. Instead it\'s "{}"'.format(
origin_axis, neighbor_axis, table.dtype
)
return NeighborTableOffsetProvider(
table,
origin_axis,
return gtx.as_connectivity(
[origin_axis, dim],
neighbor_axis,
table.shape[1],
has_skip_values=has_skip_values,
table,
skip_value=-1 if has_skip_values else None,
)
31 changes: 4 additions & 27 deletions model/common/src/icon4py/model/common/orchestration/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,6 @@ def wrapper(*args, **kwargs):
else:
return fuse_func(*args, **kwargs)

# Pytest does not clear the cache between runs in a proper way -pytest.mark.parametrize(...)-.
# This leads to corrupted cache and subsequent errors.
# To avoid this, we provide a way to clear the cache.
def clear_cache():
orchestrator_cache.clear()

wrapper.clear_cache = clear_cache

return wrapper

return _decorator(func) if func else _decorator
Expand Down Expand Up @@ -320,21 +312,6 @@ def wait(comm_handle: Union[int, decomposition.ExchangeResult]):
comm_handle.wait()


def build_compile_time_connectivities(
offset_providers: dict[str, gtx.common.Connectivity],
) -> dict[str, gtx.common.Connectivity]:
connectivities = {}
for k, v in offset_providers.items():
if hasattr(v, "table"):
connectivities[k] = gtx.otf.arguments.CompileTimeConnectivity(
v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype
)
else:
connectivities[k] = v

return connectivities


if dace:

def to_dace_annotations(fuse_func: Callable) -> dict[str, Any]:
Expand Down Expand Up @@ -547,9 +524,9 @@ def dace_specific_kwargs(
return {
# connectivity tables at runtime
**{
connectivity_identifier(k): v.table
connectivity_identifier(k): v.ndarray
for k, v in offset_providers.items()
if hasattr(v, "table")
if hasattr(v, "ndarray")
},
# GHEX C++ ptrs
"__context_ptr": expose_cpp_ptr(exchange_obj._context)
Expand Down Expand Up @@ -638,8 +615,8 @@ def _concretize_symbols_for_dace_structure(dace_cls, orig_cls):

return {
**{
"CellDim_sym": grid.offset_providers["C2E"].table.shape[0],
"EdgeDim_sym": grid.offset_providers["E2C"].table.shape[0],
"CellDim_sym": grid.offset_providers["C2E"].ndarray.shape[0],
"EdgeDim_sym": grid.offset_providers["E2C"].ndarray.shape[0],
"KDim_sym": grid.num_levels,
},
**concretize_symbols_for_dace_structure,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
sdfg = dace.SDFG("DummyNestedSDFG")
state = sdfg.add_state()

sdfg.add_scalar(name="__return", dtype=dace.int32)
sdfg.add_array(name="__return", shape=[1], dtype=dace.int32)

tasklet = dace.sdfg.nodes.Tasklet(
"DummyNestedSDFG",
Expand Down
Loading
Loading