Skip to content

Commit

Permalink
Miscellaneous lazy preprocessor improvements (#2520)
Browse files Browse the repository at this point in the history
Co-authored-by: Valeriu Predoi <valeriu.predoi@gmail.com>
  • Loading branch information
bouweandela and valeriupredoi authored Nov 25, 2024
1 parent 11e6aa8 commit d0bfb58
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 177 deletions.
66 changes: 41 additions & 25 deletions esmvalcore/preprocessor/_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from iris.exceptions import CoordinateNotFoundError

from esmvalcore.preprocessor._shared import (
apply_mask,
get_dims_along_axes,
get_iris_aggregator,
get_normalized_cube,
preserve_float_dtype,
Expand Down Expand Up @@ -188,8 +190,8 @@ def _extract_irregular_region(
cube = cube[..., i_slice, j_slice]
selection = selection[i_slice, j_slice]
# Mask remaining coordinates outside region
mask = da.broadcast_to(~selection, cube.shape)
cube.data = da.ma.masked_where(mask, cube.core_data())
horizontal_dims = get_dims_along_axes(cube, ["X", "Y"])
cube.data = apply_mask(~selection, cube.core_data(), horizontal_dims)
return cube


Expand Down Expand Up @@ -857,31 +859,45 @@ def _mask_cube(cube: Cube, masks: dict[str, np.ndarray]) -> Cube:
_cube.add_aux_coord(
AuxCoord(id_, units="no_unit", long_name="shape_id")
)
mask = da.broadcast_to(mask, _cube.shape)
_cube.data = da.ma.masked_where(~mask, _cube.core_data())
horizontal_dims = get_dims_along_axes(cube, axes=["X", "Y"])
_cube.data = apply_mask(~mask, _cube.core_data(), horizontal_dims)
cubelist.append(_cube)
result = fix_coordinate_ordering(cubelist.merge_cube())
if cube.cell_measures():
for measure in cube.cell_measures():
# Cell measures that are time-dependent, with 4 dimension and
# an original shape of (time, depth, lat, lon), need to be
# broadcasted to the cube with 5 dimensions and shape
# (time, shape_id, depth, lat, lon)
if measure.ndim > 3 and result.ndim > 4:
data = measure.core_data()
data = da.expand_dims(data, axis=(1,))
data = da.broadcast_to(data, result.shape)
measure = iris.coords.CellMeasure(
for measure in cube.cell_measures():
# Cell measures that are time-dependent, with 4 dimension and
# an original shape of (time, depth, lat, lon), need to be
# broadcast to the cube with 5 dimensions and shape
# (time, shape_id, depth, lat, lon)
if measure.ndim > 3 and result.ndim > 4:
data = measure.core_data()
if result.has_lazy_data():
# Make the cell measure lazy if the result is lazy.
cube_chunks = cube.lazy_data().chunks
chunk_dims = cube.cell_measure_dims(measure)
data = da.asarray(
data,
standard_name=measure.standard_name,
long_name=measure.long_name,
units=measure.units,
measure=measure.measure,
var_name=measure.var_name,
attributes=measure.attributes,
chunks=tuple(cube_chunks[i] for i in chunk_dims),
)
add_cell_measure(result, measure, measure.measure)
if cube.ancillary_variables():
for ancillary_variable in cube.ancillary_variables():
add_ancillary_variable(result, ancillary_variable)
chunks = result.lazy_data().chunks
else:
chunks = None
dim_map = get_dims_along_axes(result, ["T", "Z", "Y", "X"])
data = iris.util.broadcast_to_shape(
data,
result.shape,
dim_map=dim_map,
chunks=chunks,
)
measure = iris.coords.CellMeasure(
data,
standard_name=measure.standard_name,
long_name=measure.long_name,
units=measure.units,
measure=measure.measure,
var_name=measure.var_name,
attributes=measure.attributes,
)
add_cell_measure(result, measure, measure.measure)
for ancillary_variable in cube.ancillary_variables():
add_ancillary_variable(result, ancillary_variable)
return result
2 changes: 2 additions & 0 deletions esmvalcore/preprocessor/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from esmvalcore.cmor.check import CheckLevels
from esmvalcore.esgf.facets import FACETS
from esmvalcore.iris_helpers import merge_cube_attributes
from esmvalcore.preprocessor._shared import _rechunk_aux_factory_dependencies

from .._task import write_ncl_settings

Expand Down Expand Up @@ -392,6 +393,7 @@ def concatenate(cubes, check_level=CheckLevels.DEFAULT):
cubes = _sort_cubes_by_time(cubes)
_fix_calendars(cubes)
cubes = _check_time_overlaps(cubes)
cubes = [_rechunk_aux_factory_dependencies(cube) for cube in cubes]
result = _concatenate_cubes(cubes, check_level=check_level)

if len(result) == 1:
Expand Down
32 changes: 5 additions & 27 deletions esmvalcore/preprocessor/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import logging
import os
from collections.abc import Iterable
from typing import Literal, Optional
from typing import Literal

import cartopy.io.shapereader as shpreader
import dask.array as da
Expand All @@ -22,7 +21,7 @@
from iris.cube import Cube
from iris.util import rolling_window

from esmvalcore.preprocessor._shared import get_array_module
from esmvalcore.preprocessor._shared import apply_mask

from ._supplementary_vars import register_supplementaries

Expand Down Expand Up @@ -61,24 +60,6 @@ def _get_fx_mask(
return inmask


def _apply_mask(
mask: np.ndarray | da.Array,
array: np.ndarray | da.Array,
dim_map: Optional[Iterable[int]] = None,
) -> np.ndarray | da.Array:
"""Apply a (broadcasted) mask on an array."""
npx = get_array_module(mask, array)
if dim_map is not None:
if isinstance(array, da.Array):
chunks = array.chunks
else:
chunks = None
mask = iris.util.broadcast_to_shape(
mask, array.shape, dim_map, chunks=chunks
)
return npx.ma.masked_where(mask, array)


@register_supplementaries(
variables=["sftlf", "sftof"],
required="prefer_at_least_one",
Expand Down Expand Up @@ -145,7 +126,7 @@ def mask_landsea(cube: Cube, mask_out: Literal["land", "sea"]) -> Cube:
landsea_mask = _get_fx_mask(
ancillary_var.core_data(), mask_out, ancillary_var.var_name
)
cube.data = _apply_mask(
cube.data = apply_mask(
landsea_mask,
cube.core_data(),
cube.ancillary_variable_dims(ancillary_var),
Expand Down Expand Up @@ -212,7 +193,7 @@ def mask_landseaice(cube: Cube, mask_out: Literal["landsea", "ice"]) -> Cube:
landseaice_mask = _get_fx_mask(
ancillary_var.core_data(), mask_out, ancillary_var.var_name
)
cube.data = _apply_mask(
cube.data = apply_mask(
landseaice_mask,
cube.core_data(),
cube.ancillary_variable_dims(ancillary_var),
Expand Down Expand Up @@ -350,10 +331,7 @@ def _mask_with_shp(cube, shapefilename, region_indices=None):
else:
mask |= shp_vect.contains(region, x_p_180, y_p_90)

if cube.has_lazy_data():
mask = da.array(mask)

cube.data = _apply_mask(
cube.data = apply_mask(
mask,
cube.core_data(),
cube.coord_dims("latitude") + cube.coord_dims("longitude"),
Expand Down
31 changes: 1 addition & 30 deletions esmvalcore/preprocessor/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from esmvalcore.exceptions import ESMValCoreDeprecationWarning
from esmvalcore.iris_helpers import has_irregular_grid, has_unstructured_grid
from esmvalcore.preprocessor._shared import (
_rechunk_aux_factory_dependencies,
get_array_module,
get_dims_along_axes,
preserve_float_dtype,
Expand Down Expand Up @@ -1174,36 +1175,6 @@ def parse_vertical_scheme(scheme):
return scheme, extrap_scheme


def _rechunk_aux_factory_dependencies(
cube: iris.cube.Cube,
coord_name: str,
) -> iris.cube.Cube:
"""Rechunk coordinate aux factory dependencies.
This ensures that the resulting coordinate has reasonably sized
chunks that are aligned with the cube data for optimal computational
performance.
"""
# Workaround for https://github.com/SciTools/iris/issues/5457
try:
factory = cube.aux_factory(coord_name)
except iris.exceptions.CoordinateNotFoundError:
return cube

cube = cube.copy()
cube_chunks = cube.lazy_data().chunks
for coord in factory.dependencies.values():
coord_dims = cube.coord_dims(coord)
if coord_dims:
coord = coord.copy()
chunks = tuple(cube_chunks[i] for i in coord_dims)
coord.points = coord.lazy_points().rechunk(chunks)
if coord.has_bounds():
coord.bounds = coord.lazy_bounds().rechunk(chunks + (None,))
cube.replace_coord(coord)
return cube


@preserve_float_dtype
def extract_levels(
cube: iris.cube.Cube,
Expand Down
78 changes: 78 additions & 0 deletions esmvalcore/preprocessor/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,81 @@ def get_dims_along_coords(
"""Get a tuple with the dimensions along one or more coordinates."""
dims = {d for coord in coords for d in _get_dims_along(cube, coord)}
return tuple(sorted(dims))


def apply_mask(
mask: np.ndarray | da.Array,
array: np.ndarray | da.Array,
dim_map: Iterable[int],
) -> np.ma.MaskedArray | da.Array:
"""Apply a (broadcasted) mask on an array.
Parameters
----------
mask:
The mask to apply to array.
array:
The array to mask out.
dim_map :
A mapping of the dimensions of *mask* to their corresponding
dimension in *array*.
See :func:`iris.util.broadcast_to_shape` for additional details.
Returns
-------
np.ma.MaskedArray or da.Array:
A copy of the input array with the mask applied.
"""
if isinstance(array, da.Array):
array_chunks = array.chunks
# If the mask is not a Dask array yet, we make it into a Dask array
# before broadcasting to avoid inserting a large array into the Dask
# graph.
mask_chunks = tuple(array_chunks[i] for i in dim_map)
mask = da.asarray(mask, chunks=mask_chunks)
else:
array_chunks = None

mask = iris.util.broadcast_to_shape(
mask, array.shape, dim_map=dim_map, chunks=array_chunks
)

array_module = get_array_module(mask, array)
return array_module.ma.masked_where(mask, array)


def _rechunk_aux_factory_dependencies(
cube: iris.cube.Cube,
coord_name: str | None = None,
) -> iris.cube.Cube:
"""Rechunk coordinate aux factory dependencies.
This ensures that the resulting coordinate has reasonably sized
chunks that are aligned with the cube data for optimal computational
performance.
"""
# Workaround for https://github.com/SciTools/iris/issues/5457
if coord_name is None:
factories = cube.aux_factories
else:
try:
factories = [cube.aux_factory(coord_name)]
except iris.exceptions.CoordinateNotFoundError:
return cube

cube = cube.copy()
cube_chunks = cube.lazy_data().chunks
for factory in factories:
for coord in factory.dependencies.values():
coord_dims = cube.coord_dims(coord)
if coord_dims:
coord = coord.copy()
chunks = tuple(cube_chunks[i] for i in coord_dims)
coord.points = coord.lazy_points().rechunk(chunks)
if coord.has_bounds():
coord.bounds = coord.lazy_bounds().rechunk(
chunks + (None,)
)
cube.replace_coord(coord)
return cube
4 changes: 4 additions & 0 deletions esmvalcore/preprocessor/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,10 @@ def timeseries_filter(
# Apply filter
(agg, agg_kwargs) = get_iris_aggregator(filter_stats, **operator_kwargs)
agg_kwargs["weights"] = wgts
if cube.has_lazy_data():
# Ensure the cube data chunktype is np.MaskedArray so rolling_window
# does not ignore a potential mask.
cube.data = da.ma.masked_array(cube.core_data())
cube = cube.rolling_window("time", agg, len(wgts), **agg_kwargs)

return cube
Expand Down
Loading

0 comments on commit d0bfb58

Please sign in to comment.