Skip to content

Commit

Permalink
add Optional to backend annotation to allow None except driver initia…
Browse files Browse the repository at this point in the history
…lization
  • Loading branch information
OngChia committed Jan 17, 2025
1 parent 027f770 commit 1ec6558
Show file tree
Hide file tree
Showing 21 changed files with 96 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from enum import Enum, auto
import dataclasses
import logging
from typing import Optional

import icon4py.model.common.grid.states as grid_states
from gt4py.next import backend
from gt4py.next import backend as gtx_backend

from icon4py.model.atmosphere.advection import (
advection_states,
Expand Down Expand Up @@ -159,7 +160,7 @@ class NoAdvection(Advection):
def __init__(
self,
grid: icon_grid.IconGrid,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
log.debug("advection class init - start")
Expand Down Expand Up @@ -215,7 +216,7 @@ def __init__(
vertical_advection: advection_vertical.VerticalAdvection,
grid: icon_grid.IconGrid,
metric_state: advection_states.AdvectionMetricState,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
even_timestep: bool = False,
):
Expand Down Expand Up @@ -381,7 +382,7 @@ def convert_config_to_horizontal_vertical_advection(
metric_state: advection_states.AdvectionMetricState,
edge_params: grid_states.EdgeParams,
cell_params: grid_states.CellParams,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
) -> tuple[advection_horizontal.HorizontalAdvection, advection_vertical.VerticalAdvection]:
match config.horizontal_advection_limiter:
Expand Down Expand Up @@ -463,7 +464,7 @@ def convert_config_to_advection(
metric_state: advection_states.AdvectionMetricState,
edge_params: grid_states.EdgeParams,
cell_params: grid_states.CellParams,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
even_timestep: bool = False,
) -> Advection:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from abc import ABC, abstractmethod
import logging
from typing import Optional

import icon4py.model.common.grid.states as grid_states
from gt4py.next import backend
from gt4py.next import backend as gtx_backend

from icon4py.model.atmosphere.advection import advection_states

Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(
self,
grid: icon_grid.IconGrid,
interpolation_state: advection_states.AdvectionInterpolationState,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
self._grid = grid
Expand Down Expand Up @@ -190,7 +191,7 @@ def __init__(
self,
grid: icon_grid.IconGrid,
least_squares_state: advection_states.AdvectionLeastSquaresState,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
horizontal_limiter: HorizontalFluxLimiter = HorizontalFluxLimiter(),
):
self._grid = grid
Expand Down Expand Up @@ -326,7 +327,7 @@ def run(
class NoAdvection(HorizontalAdvection):
"""Class that implements disabled horizontal advection."""

def __init__(self, grid: icon_grid.IconGrid, backend: backend.Backend):
def __init__(self, grid: icon_grid.IconGrid, backend: Optional[gtx_backend.Backend]):
log.debug("horizontal advection class init - start")

# input arguments
Expand Down Expand Up @@ -440,7 +441,7 @@ def __init__(
metric_state: advection_states.AdvectionMetricState,
edge_params: grid_states.EdgeParams,
cell_params: grid_states.CellParams,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
log.debug("horizontal advection class init - start")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

from abc import ABC, abstractmethod
import logging
from typing import Optional

import icon4py.model.common.grid.states as grid_states
import gt4py.next as gtx
from gt4py.next import backend
from gt4py.next import backend as gtx_backend

from icon4py.model.atmosphere.advection import advection_states

Expand Down Expand Up @@ -105,7 +106,7 @@ def run(
class NoFluxCondition(BoundaryConditions):
"""Class that sets the upper and lower boundary fluxes to zero."""

def __init__(self, grid: icon_grid.IconGrid, backend: backend.Backend):
def __init__(self, grid: icon_grid.IconGrid, backend: Optional[gtx_backend.Backend]):
# input arguments
self._grid = grid
self._backend = backend
Expand Down Expand Up @@ -186,7 +187,7 @@ def limit_fluxes(
class NoLimiter(VerticalLimiter):
"""Class that implements no vertical parabola limiter."""

def __init__(self, grid: icon_grid.IconGrid, backend: backend.Backend):
def __init__(self, grid: icon_grid.IconGrid, backend: Optional[gtx_backend.Backend]):
# input arguments
self._grid = grid
self._backend = backend
Expand Down Expand Up @@ -256,7 +257,7 @@ def limit_fluxes(
class SemiMonotonicLimiter(VerticalLimiter):
"""Class that implements a semi-monotonic vertical parabola limiter."""

def __init__(self, grid: icon_grid.IconGrid, backend: backend.Backend):
def __init__(self, grid: icon_grid.IconGrid, backend: Optional[gtx_backend.Backend]):
# input arguments
self._grid = grid
self._backend = backend
Expand Down Expand Up @@ -389,7 +390,7 @@ def run(
class NoAdvection(VerticalAdvection):
"""Class that implements disabled vertical advection."""

def __init__(self, grid: icon_grid.IconGrid, backend: backend.Backend):
def __init__(self, grid: icon_grid.IconGrid, backend: Optional[gtx_backend.Backend]):
log.debug("vertical advection class init - start")

# input arguments
Expand Down
15 changes: 11 additions & 4 deletions model/atmosphere/advection/tests/advection_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import logging
from typing import Optional

import gt4py.next as gtx
from gt4py.next import backend as gtx_backend
Expand Down Expand Up @@ -38,7 +39,7 @@ def construct_config(


def construct_interpolation_state(
savepoint: sb.InterpolationSavepoint, backend: gtx_backend.Backend
savepoint: sb.InterpolationSavepoint, backend: Optional[gtx_backend.Backend]
) -> advection_states.AdvectionInterpolationState:
return advection_states.AdvectionInterpolationState(
geofac_div=data_alloc.as_1D_sparse_field(
Expand All @@ -60,7 +61,7 @@ def construct_least_squares_state(


def construct_metric_state(
icon_grid, savepoint: sb.MetricSavepoint, backend: gtx_backend.Backend
icon_grid, savepoint: sb.MetricSavepoint, backend: Optional[gtx_backend.Backend]
) -> advection_states.AdvectionMetricState:
constant_f = data_alloc.constant_field(icon_grid, 1.0, dims.KDim, backend=backend)
ddqz_z_full_np = np.reciprocal(savepoint.inv_ddqz_z_full().asnumpy())
Expand All @@ -73,7 +74,10 @@ def construct_metric_state(


def construct_diagnostic_init_state(
icon_grid, savepoint: sb.AdvectionInitSavepoint, ntracer: int, backend: gtx_backend.Backend
icon_grid,
savepoint: sb.AdvectionInitSavepoint,
ntracer: int,
backend: Optional[gtx_backend.Backend],
) -> advection_states.AdvectionDiagnosticState:
return advection_states.AdvectionDiagnosticState(
airmass_now=savepoint.airmass_now(),
Expand All @@ -89,7 +93,10 @@ def construct_diagnostic_init_state(


def construct_diagnostic_exit_state(
icon_grid, savepoint: sb.AdvectionInitSavepoint, ntracer: int, backend: gtx_backend.Backend
icon_grid,
savepoint: sb.AdvectionInitSavepoint,
ntracer: int,
backend: Optional[gtx_backend.Backend],
) -> advection_states.AdvectionDiagnosticState:
zero_f = data_alloc.allocate_zero_field(
dims.CellDim, dims.KDim, grid=icon_grid, backend=backend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import logging
import math
import sys
from typing import Final
from typing import Final, Optional

import gt4py.next as gtx
import icon4py.model.common.grid.states as grid_states
from gt4py.next import int32

import icon4py.model.common.states.prognostic_state as prognostics
from gt4py.next import backend
from gt4py.next import backend as gtx_backend

from icon4py.model.atmosphere.diffusion import diffusion_utils, diffusion_states
from icon4py.model.atmosphere.diffusion.diffusion_utils import (
Expand Down Expand Up @@ -363,7 +363,7 @@ def __init__(
interpolation_state: diffusion_states.DiffusionInterpolationState,
edge_params: grid_states.EdgeParams,
cell_params: grid_states.CellParams,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
orchestration: bool = False,
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Final, Optional

import gt4py.next as gtx
from gt4py.next import backend
from gt4py.next import backend as gtx_backend

import icon4py.model.atmosphere.dycore.solve_nonhydro_stencils as nhsolve_stencils
import icon4py.model.common.grid.states as grid_states
Expand Down Expand Up @@ -248,7 +248,7 @@ class IntermediateFields:
def allocate(
cls,
grid: grid_def.BaseGrid,
backend: Optional[backend.Backend] = None,
backend: Optional[gtx_backend.Backend] = None,
):
return IntermediateFields(
z_gradh_exner=data_alloc.allocate_zero_field(
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(
edge_geometry: grid_states.EdgeParams,
cell_geometry: grid_states.CellParams,
owner_mask: fa.CellField[bool],
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
self._exchange = exchange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import Optional

import gt4py.next as gtx
from gt4py.next import backend
from gt4py.next import backend as gtx_backend

import icon4py.model.atmosphere.dycore.velocity_advection_stencils as velocity_stencils
from icon4py.model.atmosphere.dycore import dycore_states
Expand Down Expand Up @@ -63,7 +65,7 @@ def __init__(
vertical_params: v_grid.VerticalGrid,
edge_params: grid_states.EdgeParams,
owner_mask: fa.CellField[bool],
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
):
self.grid: icon_grid.IconGrid = grid
self._backend = backend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
from typing import Final
from typing import Final, Optional

import gt4py.next as gtx
import numpy as np
from gt4py.eve.utils import FrozenNamespace
from gt4py.next import backend, broadcast
from gt4py.next import backend as gtx_backend, broadcast
from gt4py.next.ffront.fbuiltins import (
abs,
exp,
Expand Down Expand Up @@ -226,7 +226,7 @@ def __init__(
grid: icon_grid.IconGrid,
vertical_params: v_grid.VerticalGrid,
metric_state: MetricStateSaturationAdjustment,
backend: backend.Backend,
backend: Optional[gtx_backend.Backend],
):
self._backend = backend
self.config = config
Expand Down
2 changes: 1 addition & 1 deletion model/common/src/icon4py/model/common/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self,
grid: icon.IconGrid,
decomposition_info: definitions.DecompositionInfo,
backend: gtx_backend.Backend,
backend: Optional[gtx_backend.Backend],
coordinates: gm.CoordinateDict,
extra_fields: dict[InputGeometryFieldType, gtx.Field],
metadata: dict[str, model.FieldMetaData],
Expand Down
2 changes: 1 addition & 1 deletion model/common/src/icon4py/model/common/grid/grid_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def _read_start_end_indices(

def _read_grid_refinement_fields(
self,
backend: gtx_backend.Backend,
backend: Optional[gtx_backend.Backend],
decomposition_info: Optional[decomposition.DecompositionInfo] = None,
) -> tuple[dict[dims.Dimension : data_alloc.NDArray]]:
"""
Expand Down
8 changes: 4 additions & 4 deletions model/common/src/icon4py/model/common/grid/vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import math
import pathlib
from typing import Final
from typing import Final, Optional

import gt4py.next as gtx
import numpy as np
Expand Down Expand Up @@ -275,7 +275,7 @@ def _determine_end_index_of_flat_layers(


def _read_vct_a_and_vct_b_from_file(
file_path: pathlib.Path, num_levels: int, backend: gtx_backend.Backend
file_path: pathlib.Path, num_levels: int, backend: Optional[gtx_backend.Backend]
) -> tuple[fa.KField, fa.KField]:
"""
Read vct_a and vct_b from a file.
Expand Down Expand Up @@ -321,7 +321,7 @@ def _read_vct_a_and_vct_b_from_file(


def _compute_vct_a_and_vct_b(
vertical_config: VerticalGridConfig, backend: gtx_backend.Backend
vertical_config: VerticalGridConfig, backend: Optional[gtx_backend.Backend]
) -> tuple[fa.KField, fa.KField]:
"""
Compute vct_a and vct_b.
Expand Down Expand Up @@ -507,7 +507,7 @@ def _compute_vct_a_and_vct_b(


def get_vct_a_and_vct_b(
vertical_config: VerticalGridConfig, backend: gtx_backend.Backend
vertical_config: VerticalGridConfig, backend: Optional[gtx_backend.Backend]
) -> tuple[fa.KField, fa.KField]:
"""
get vct_a and vct_b.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import functools
from typing import Optional

import gt4py.next as gtx
from gt4py.next import backend as gtx_backend
Expand Down Expand Up @@ -36,7 +37,7 @@ def __init__(
grid: icon.IconGrid,
decomposition_info: definitions.DecompositionInfo,
geometry_source: geometry.GridGeometry,
backend: gtx_backend.Backend,
backend: Optional[gtx_backend.Backend],
metadata: dict[str, model.FieldMetaData],
):
self._backend = backend
Expand Down
Loading

0 comments on commit 1ec6558

Please sign in to comment.