Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDAT Migration Phase 2: Regression testing for lat_lon, lat_lon_land, and lat_lon_river #744

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9d5564c
Starter code for fixing integration tests
tomvothecoder Oct 16, 2023
a599ac1
Fix 3D variable bugs
tomvothecoder Oct 17, 2023
728b7b8
Fix regridding tests due to bounds not being converted
tomvothecoder Oct 18, 2023
27db54d
Uncomment integration test related lines
tomvothecoder Oct 18, 2023
d44742c
Try Union to fix integration tests
tomvothecoder Oct 18, 2023
b9865fc
Remove TODO comments
tomvothecoder Oct 18, 2023
2d46581
Fix for loop subsetting in `_run_3d_diags()`
tomvothecoder Oct 19, 2023
c923df2
Move call for adding bounds to `regrid_z_axis_to_plevs()`
tomvothecoder Oct 19, 2023
ee185ed
Add comments in `test_all_sets.py`
tomvothecoder Oct 19, 2023
ff6145e
Fix type annotation support for pipe in py3.9
tomvothecoder Oct 19, 2023
417f92f
update comment
tomvothecoder Oct 19, 2023
f7c2b17
Remove repeated call to save plots and metrics
tomvothecoder Oct 19, 2023
21f3bfc
Fix incorrect import of `_get_output_dir()`
tomvothecoder Oct 19, 2023
1371172
Revert VS Code debugger settings
tomvothecoder Oct 23, 2023
5e65378
Apply suggestions from code review
tomvothecoder Oct 23, 2023
5811454
Clean up `test_diags.py`
tomvothecoder Oct 23, 2023
79f04d4
Update order of methods in `test_diags.py`
tomvothecoder Oct 23, 2023
0f77d92
Update fixture in `test_diags.py` to only run once
tomvothecoder Oct 23, 2023
b367f61
Update order of methods
tomvothecoder Oct 23, 2023
0e66bc8
Fix conditional for bbox
tomvothecoder Oct 23, 2023
90e39ab
Uncomment pytest testpaths config
tomvothecoder Oct 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ docs: ## generate Sphinx HTML documentation, including API docs
# Build
# ----------------------
install: clean ## install the package to the active Python's site-packages
python setup.py install
python -m pip install .
50 changes: 24 additions & 26 deletions e3sm_diags/driver/lat_lon_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import json
import os
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

import xarray as xr

from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.general import get_output_dir
from e3sm_diags.driver.utils.io import _write_vars_to_netcdf
from e3sm_diags.driver.utils.io import _get_output_dir, _write_vars_to_netcdf
from e3sm_diags.driver.utils.regrid import (
_apply_land_sea_mask,
_subset_on_region,
Expand All @@ -27,7 +26,9 @@
# type of metrics and the value is a sub-dictionary of metrics (key is metrics
# type and value is float). There is also a "unit" key representing the
# units for the variable.
MetricsDict = Dict[str, str | Dict[str, float | None | List[float]]]
UnitAttr = str
MetricsSubDict = Dict[str, Union[float, None, List[float]]]
MetricsDict = Dict[str, Union[UnitAttr, MetricsSubDict]]

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter
Expand Down Expand Up @@ -116,7 +117,6 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
ref_name,
)
elif is_vars_3d:
# TODO: Test this conditional with 3D variables.
_run_diags_3d(
parameter,
ds_test,
Expand Down Expand Up @@ -238,14 +238,13 @@ def _run_diags_3d(
plev = parameter.plevs
logger.info("Selected pressure level(s): {}".format(plev))

ds_test = regrid_z_axis_to_plevs(ds_test, var_key, parameter.plevs)
ds_ref = regrid_z_axis_to_plevs(ds_ref, var_key, parameter.plevs)
ds_test_rg = regrid_z_axis_to_plevs(ds_test, var_key, parameter.plevs)
ds_ref_rg = regrid_z_axis_to_plevs(ds_ref, var_key, parameter.plevs)

for ilev, _ in enumerate(plev):
# TODO: Test the subsetting here with 3D variables
z_axis = get_z_axis(ds_test[var_key])
ds_test_ilev = ds_test.isel({f"{z_axis}": ilev})
ds_ref_ilev = ds_ref.isel({f"{z_axis}": ilev})
for ilev in plev:
z_axis_key = get_z_axis(ds_test_rg[var_key]).name
ds_test_ilev = ds_test_rg.sel({z_axis_key: ilev})
ds_ref_ilev = ds_ref_rg.sel({z_axis_key: ilev})

for region in regions:
(
Expand Down Expand Up @@ -307,12 +306,15 @@ def _set_param_output_attrs(
The parameter with updated output attributes.
"""
if ilev is None:
parameter.output_file = f"{ref_name}-{var_key}-{season}-{region}"
parameter.main_title = f"{var_key} {season} {region}"
output_file = f"{ref_name}-{var_key}-{season}-{region}"
main_title = f"{var_key} {season} {region}"
else:
ilev_str = str(int(ilev))
parameter.output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}"
parameter.main_title = f"{var_key} {ilev_str} 'mb' {season} {region}"
output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}"
main_title = f"{var_key} {ilev_str} 'mb' {season} {region}"

parameter.output_file = output_file
parameter.main_title = main_title

return parameter

Expand Down Expand Up @@ -396,10 +398,6 @@ def _get_metrics_by_region(
var_key, ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff
)

_save_data_metrics_and_plots(
parameter, var_key, metrics_dict, ds_test, ds_ref, ds_diff
)

return metrics_dict, ds_test, ds_ref, ds_diff


Expand Down Expand Up @@ -556,14 +554,14 @@ def _save_data_metrics_and_plots(
ds_diff,
)

filename = os.path.join(
get_output_dir(parameter.current_set, parameter),
parameter.output_file + ".json",
)
with open(filename, "w") as outfile:
output_dir = _get_output_dir(parameter)
filename = f"{parameter.output_file}.json"
filepath = os.path.join(output_dir, filename)

with open(filepath, "w") as outfile:
json.dump(metrics_dict, outfile)

logger.info(f"Metrics saved in {filename}")
logger.info(f"Metrics saved in {filepath}")

# Set the viewer description to the "long_name" attr of the variable.
parameter.viewer_descr[var_key] = ds_test[var_key].attrs.get(
Expand Down
32 changes: 19 additions & 13 deletions e3sm_diags/driver/utils/dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,22 @@ def __init__(
"Valid options include 'ref' or 'test'."
)

# Set the `start_yr` and `end_yr` attrs based on the dataset type.
# Note, these attrs are different for the `area_mean_time_series`
# parameter.
if self.parameter.sets[0] in ["area_mean_time_series"]:
self.start_yr = self.parameter.start_yr # type: ignore
self.end_yr = self.parameter.end_yr # type: ignore
elif self.data_type == "ref":
self.start_yr = self.parameter.ref_start_yr # type: ignore
self.end_yr = self.parameter.ref_end_yr # type: ignore
elif self.data_type == "test":
self.start_yr = self.parameter.test_start_yr # type: ignore
self.end_yr = self.parameter.test_end_yr # type: ignore
# If the underlying data is a time series, set the `start_yr` and
# `end_yr` attrs based on the data type (ref or test). Note, these attrs
# are different for the `area_mean_time_series` parameter.
if self.is_time_series:
# FIXME: This conditional should not assume the first set is
# area_mean_time_series. If area_mean_time_series is at another
# index, this conditional evaluates to False.
if self.parameter.sets[0] in ["area_mean_time_series"]:
self.start_yr = self.parameter.start_yr # type: ignore
self.end_yr = self.parameter.end_yr # type: ignore
elif self.data_type == "ref":
self.start_yr = self.parameter.ref_start_yr # type: ignore
self.end_yr = self.parameter.ref_end_yr # type: ignore
elif self.data_type == "test":
self.start_yr = self.parameter.test_start_yr # type: ignore
self.end_yr = self.parameter.test_end_yr # type: ignore

# The derived variables defined in E3SM Diags. If the `CoreParameter`
# object contains additional user derived variables, they are added
Expand Down Expand Up @@ -969,7 +973,9 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset:
ds_land_frac = self.get_climo_dataset(LAND_FRAC_KEY, season) # type: ignore
ds_ocean_frac = self.get_climo_dataset(OCEAN_FRAC_KEY, season) # type: ignore
except IOError as e:
logger.warning(e)
logger.warning(
f"{e}. Using default land sea mask located at `{LAND_OCEAN_MASK_PATH}`."
)

ds_mask = xr.open_dataset(LAND_OCEAN_MASK_PATH)
ds_mask = self._squeeze_time_dim(ds_mask)
Expand Down
101 changes: 75 additions & 26 deletions e3sm_diags/driver/utils/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,16 +359,22 @@ def regrid_z_axis_to_plevs(
Replaces `e3sm_diags.driver.utils.general.convert_to_pressure_levels`.
"""
ds = dataset.copy()

# Make sure that the input dataset has Z axis bounds, which are required for
# getting grid positions during vertical regridding.
try:
ds.bounds.get_bounds("Z")
except KeyError:
ds = ds.bounds.add_bounds("Z")

z_axis = get_z_axis(ds[var_key])
z_long_name = z_axis.attrs.get("long_name")

if z_long_name is None:
raise KeyError(
f"The vertical level ({z_axis.name}) for '{var_key}' does "
"not have a 'long_name' attribute to determine whether it is hybrid "
"or pressure."
)

z_long_name = z_long_name.lower()

# Hybrid must be the first conditional statement because the long_name attr
Expand All @@ -385,8 +391,11 @@ def regrid_z_axis_to_plevs(
"'pressure', or 'isobaric'."
)

# Add bounds for the new, regridded Z axis.
ds_plevs = ds_plevs.bounds.add_bounds(axis="Z")
# Add bounds for the new, regridded Z axis if the length is greater than 1.
# xCDAT does not support adding bounds for singleton coordinates.
new_z_axis = get_z_axis(ds_plevs[var_key])
if len(new_z_axis) > 1:
ds_plevs = ds_plevs.bounds.add_bounds("Z")

return ds_plevs

Expand Down Expand Up @@ -423,20 +432,12 @@ def _hybrid_to_plevs(
-----
Replaces `e3sm_diags.driver.utils.general.hybrid_to_plevs`.
"""
# TODO: Do we need to convert the Z axis to mb units if it is in PA? Or
# do we always expect units to be in mb?
# TODO: mb units are always expected, but we should consider checking
# the units to confirm whether or not unit conversion is needed.
z_axis, _ = xc.create_axis("lev", plevs, generate_bounds=False)

pressure_grid = xc.create_grid(z=z_axis)
pressure_coords = _hybrid_to_pressure(ds, var_key)

# Make sure that the input dataset has Z axis bounds, which are required for
# getting grid positions during vertical regridding.
try:
ds.bounds.get_bounds("Z")
except KeyError:
ds = ds.bounds.add_bounds("Z")

# Keep the "axis" and "coordinate" attributes for CF mapping.
with xr.set_options(keep_attrs=True):
result = ds.regridder.vertical(
Expand Down Expand Up @@ -497,7 +498,7 @@ def _hybrid_to_pressure(ds: xr.Dataset, var_key: str) -> xr.DataArray:
"'hyam' and/or 'hybm' to use for reconstructing to pressure data."
)

ps = _convert_units_to_mb(ps)
ps = _convert_dataarray_units_to_mb(ps)

pressure_coords = hyam * p0 + hybm * ps
pressure_coords.attrs["units"] = "mb"
Expand Down Expand Up @@ -564,14 +565,13 @@ def _pressure_to_plevs(
-----
Replaces `e3sm_diags.driver.utils.general.pressure_to_plevs`.
"""
# Convert pressure coordinates and bounds to mb if it is not already in mb.
ds = _convert_dataset_units_to_mb(ds, var_key)

# Create the output pressure grid to regrid to using the `plevs` array.
z_axis, _ = xc.create_axis("lev", plevs, generate_bounds=False)
pressure_grid = xc.create_grid(z=z_axis)

# Convert pressure coordinates to mb if it is not already in mb.
lev_key = xc.get_dim_keys(ds[var_key], axis="Z")
ds[lev_key] = _convert_units_to_mb(ds[lev_key])

# Keep the "axis" and "coordinate" attributes for CF mapping.
with xr.set_options(keep_attrs=True):
result = ds.regridder.vertical(
Expand All @@ -584,10 +584,57 @@ def _pressure_to_plevs(
return result


def _convert_units_to_mb(da: xr.DataArray) -> xr.DataArray:
"""Convert DataArray to mb (millibars) if not in mb.
def _convert_dataset_units_to_mb(ds: xr.Dataset, var_key: str) -> xr.Dataset:
    """Return the dataset with the variable's Z axis and Z bounds in mb.

    The unit conversion itself is delegated to
    ``_convert_dataarray_units_to_mb``, which is applied to both the Z axis
    coordinates and the Z axis bounds of ``var_key``.

    Parameters
    ----------
    ds : xr.Dataset
        The dataset holding the variable, its Z axis, and its Z bounds.
    var_key : str
        The key of the variable whose Z axis and Z bounds are converted.

    Returns
    -------
    xr.Dataset
        The dataset with the converted Z axis coordinates and Z bounds.

    Raises
    ------
    KeyError
        If the Z axis has no "units" attribute.
    RuntimeError
        If the Z bounds have a "units" attribute that differs from the
        "units" attribute of the Z axis.
    """
    z_axis = xc.get_dim_coords(ds[var_key], axis="Z")
    z_bnds = ds.bounds.get_bounds(axis="Z", var_key=var_key)

    z_axis_units = z_axis.attrs["units"]
    z_bnds_units = z_bnds.attrs.get("units")

    # Bounds frequently carry no units attr; only an explicit, conflicting
    # value is treated as an error.
    units_agree = z_bnds_units is None or z_bnds_units == z_axis_units
    if not units_agree:
        raise RuntimeError(
            f"The units for '{z_bnds.name}' ({z_bnds_units}) "
            f"does not align with '{z_axis.name}' ({z_axis_units}). "
        )

    # Give the bounds the same units as the axis so that both go through the
    # same conversion path below.
    z_bnds.attrs["units"] = z_axis_units

    # Convert both objects first, then write them back into the dataset.
    converted_axis = _convert_dataarray_units_to_mb(z_axis)
    converted_bnds = _convert_dataarray_units_to_mb(z_bnds)

    ds = ds.assign_coords({z_axis.name: converted_axis})

    # Keep the bounds' own Z coordinate in sync with the converted axis
    # before storing the bounds back on the dataset.
    converted_bnds[z_axis.name] = converted_axis
    ds[z_bnds.name] = converted_bnds

    return ds


def _convert_dataarray_units_to_mb(da: xr.DataArray) -> xr.DataArray:
"""Convert a dataarray to mb (millibars) if they are not in mb.

Unit conversion formulas:
* hPa = mb
* mb = Pa / 100
* Pa = (mb * 100)

Expand All @@ -614,17 +661,19 @@ def _convert_units_to_mb(da: xr.DataArray) -> xr.DataArray:

if units is None:
raise ValueError(
"'{ps.name}' has no 'units' attribute to determine if data is in 'mb' or "
"'Pa' units."
f"'{da.name}' has no 'units' attribute to determine if data is in'mb', "
"'hPa', or 'Pa' units."
)

if units == "mb":
pass
elif units == "Pa":
if units == "Pa":
with xr.set_options(keep_attrs=True):
da = da / 100.0

da.attrs["units"] = "mb"
elif units == "hPa":
da.attrs["units"] = "mb"
elif units == "mb":
pass
else:
raise ValueError(
f"'{da.name}' should be in 'mb' or 'Pa' (which gets converted to 'mb'), "
Expand Down
1 change: 1 addition & 0 deletions e3sm_diags/plot/cartopy/arm_diags_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def plot_convection_onset_statistics(
var_time_absolute = cwv.getTime().asComponentTime()
time_interval = int(var_time_absolute[1].hour - var_time_absolute[0].hour)

# FIXME: UnboundLocalError: local variable 'cwv_max' referenced before assignment
number_of_bins = int(np.ceil((cwv_max - cwv_min) / bin_width))
bin_center = np.arange(
(cwv_min + (bin_width / 2)),
Expand Down
4 changes: 4 additions & 0 deletions e3sm_diags/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def get_final_parameters(self, parameters):
"""
Based on sets_to_run and the list of parameters,
get the final list of parameters to run the diags on.
FIXME: This function was only designed to take in 1 parameter at a
time or a mix of different parameters. If there are two
CoreParameter objects, it will break.
"""
if not parameters or not isinstance(parameters, list):
msg = "You must pass in a list of parameter objects."
Expand Down
3 changes: 2 additions & 1 deletion tests/e3sm_diags/driver/utils/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,11 @@ def test_regrids_pressure_coordinates_to_pressure_levels_with_Pa_units(
attrs={"xcdat_bounds": "True"},
)

# Update from Pa to mb.
# Update mb to Pa so this test can make sure conversions to mb are done.
ds_pa = ds.copy()
with xr.set_options(keep_attrs=True):
ds_pa["lev"] = ds_pa.lev * 100
ds_pa["lev_bnds"] = ds_pa.lev_bnds * 100
ds_pa.lev.attrs["units"] = "Pa"

result = regrid_z_axis_to_plevs(ds_pa, "so", self.plevs)
Expand Down
Loading