Skip to content

Commit

Permalink
Refactor logic for preserving coordinates with regrid2
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Dec 6, 2024
1 parent e2d259e commit a8732a8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 49 deletions.
12 changes: 0 additions & 12 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,18 +517,6 @@ def test_unknown_variable(self):
with pytest.raises(KeyError):
regridder.horizontal("unknown", self.coarse_2d_ds)

def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self):
    """Expect ``horizontal`` to raise ``ValueError`` when ``lat`` lacks ``axis``.

    The ``axis`` attribute is removed from the ``lat`` coordinate (leaving
    only ``standard_name``) before the regridder is built and invoked.
    """
    ds_no_axis = self.coarse_2d_ds.copy()
    ds_no_axis["lat"].attrs.update(standard_name="latitude")
    del ds_no_axis["lat"].attrs["axis"]

    expected_msg = "Could not determine axis name for dimension"
    regridder = regrid2.Regrid2Regridder(ds_no_axis, self.fine_2d_ds)

    with pytest.raises(ValueError, match=expected_msg):
        regridder.horizontal("ts", ds_no_axis)

@pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning")
def test_regrid_input_mask(self):
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)
Expand Down
113 changes: 76 additions & 37 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import xarray as xr

from xcdat.axis import CF_ATTR_MAP, get_dim_keys
import xcdat as xc
from xcdat.axis import VAR_NAME_MAP, get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds

# Keys that identify the spatial (X and Y) axes. They are used to match the
# axes of an input data variable when building the output variable.
VALID_SPATIAL_AXES_KEYS = ["X", "Y"] + VAR_NAME_MAP["X"] + VAR_NAME_MAP["Y"]


class Regrid2Regridder(BaseRegridder):
def __init__(
Expand Down Expand Up @@ -229,48 +234,87 @@ def _build_dataset(
input_grid: xr.Dataset,
output_grid: xr.Dataset,
) -> xr.Dataset:
input_data_var = ds[data_var]
"""Build a new xarray Dataset with the given output data and coordinates.
output_coords: dict[str, xr.DataArray] = {}
output_data_vars: dict[str, xr.DataArray] = {}
Parameters
----------
ds : xr.Dataset
The input dataset containing the data variable to be regridded.
data_var : str
The name of the data variable in the input dataset to be regridded.
output_data : np.ndarray
The regridded data to be included in the output dataset.
input_grid : xr.Dataset
The input grid dataset containing the original grid information.
output_grid : xr.Dataset
The output grid dataset containing the new grid information.
for dim in input_data_var.dims:
dim = str(dim)
Returns
-------
xr.Dataset
A new dataset containing the regridded data variable with updated
coordinates and attributes.
"""
dv_input = ds[data_var]

try:
axis_name = [
cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims
][0]
except IndexError as e:
raise ValueError(
f"Could not determine axis name for dimension {dim}"
) from e

if axis_name in ["X", "Y"]:
output_coords[dim] = output_grid.cf[axis_name]
else:
output_coords[dim] = input_data_var.cf[axis_name]
output_coords = _get_output_coords(dv_input, output_grid)

output_da = xr.DataArray(
output_data,
dims=input_data_var.dims,
dims=dv_input.dims,
coords=output_coords,
attrs=ds[data_var].attrs.copy(),
name=data_var,
)

output_data_vars[data_var] = output_da

output_ds = xr.Dataset(
output_data_vars,
attrs=input_grid.attrs.copy(),
)

output_ds = output_da.to_dataset()
output_ds.attrs = input_grid.attrs.copy()
output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"])

return output_ds


def _get_output_coords(
    dv_input: xr.DataArray, output_grid: xr.Dataset
) -> Dict[str, xr.DataArray]:
    """Assemble the coordinates for a regridded data variable.

    The X and Y coordinates are taken from the output grid and stored under
    the names of the matching X and Y coordinates of the input data variable.
    Every other dimension keeps the coordinate of the input data variable.

    Parameters
    ----------
    dv_input : xr.DataArray
        The input data variable containing the original coordinates.
    output_grid : xr.Dataset
        The dataset containing the target grid coordinates.

    Returns
    -------
    Dict[str, xr.DataArray]
        Coordinates keyed by dimension name, in the same order as
        ``dv_input.dims``.
    """
    coords_by_dim: Dict[str, xr.DataArray] = {}

    # Spatial axes: values come from the output grid, but the key is the name
    # of the corresponding coordinate on the input data variable.
    for axis in ("X", "Y"):
        source_coord = xc.get_dim_coords(dv_input, axis)  # type: ignore
        target_coord = xc.get_dim_coords(output_grid, axis)  # type: ignore

        coords_by_dim[str(source_coord.name)] = target_coord  # type: ignore

    # Remaining axes (e.g., "time") are carried over from the input data
    # variable unchanged.
    for dim in dv_input.dims:
        if dim in coords_by_dim:
            continue

        coords_by_dim[str(dim)] = dv_input[dim]

    # Rebuild the mapping so its order follows the input data variable dims.
    return {str(dim): coords_by_dim[str(dim)] for dim in dv_input.dims}


def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
Expand Down Expand Up @@ -564,17 +608,12 @@ def _get_dimension(input_data_var, cf_axis_name):


def _get_bounds_ensure_dtype(ds, axis):
cf_keys = CF_ATTR_MAP[axis].values()

bounds = None

for key in cf_keys:
try:
name = ds.cf.bounds[key][0]
except (KeyError, IndexError):
pass
else:
bounds = ds[name]
try:
bounds = ds.bounds.get_bounds(axis)
except KeyError:
pass

if bounds is None:
raise RuntimeError(f"Could not determine {axis!r} bounds")
Expand Down

0 comments on commit a8732a8

Please sign in to comment.