From 8ca24c3e31e06c753742e158a02525ed17042f86 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Wed, 22 Mar 2023 11:53:19 -0700 Subject: [PATCH] Update new `climo()` with DataArray parsing - Add FIXME comment in old `climo()` that breaks with "ANN" season --- e3sm_diags/driver/lat_lon_driver.py | 20 +++++----- e3sm_diags/driver/utils/climo.py | 4 ++ e3sm_diags/driver/utils/climo_new.py | 54 ++++++++++++++++---------- e3sm_diags/driver/utils/dataset_new.py | 38 +++++++----------- 4 files changed, 64 insertions(+), 52 deletions(-) diff --git a/e3sm_diags/driver/lat_lon_driver.py b/e3sm_diags/driver/lat_lon_driver.py index 25a1a5acbd..b57b60fb71 100755 --- a/e3sm_diags/driver/lat_lon_driver.py +++ b/e3sm_diags/driver/lat_lon_driver.py @@ -132,10 +132,11 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901 # Get land/ocean fraction for masking. try: - land_frac = test_data.get_climo_variable("LANDFRAC", season) - ocean_frac = test_data.get_climo_variable("OCNFRAC", season) - # FIXME: Capture the exact exceptions, not the general Exception (bad practice). - except Exception: + land_frac = test_data.get_climo_variable("LANDFRAC", season) # type: ignore + ocean_frac = test_data.get_climo_variable("OCNFRAC", season) # type: ignore + except RuntimeError as e: + logger.warning(e) + mask_path = os.path.join( e3sm_diags.INSTALL_PATH, "acme_ne30_ocean_land_mask.nc" ) @@ -144,14 +145,15 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901 land_frac = ds_mask["LANDFRAC"] ocean_frac = ds_mask["OCNFRAC"] + # TODO: Now we are here. parameter.model_only = False for var in variables: logger.info("Variable: {}".format(var)) parameter.var_id = var - mv1 = test_data.get_climo_variable(var, season) + mv1 = test_data.get_climo_variable(var, season) # type: ignore try: - mv2 = ref_data.get_climo_variable(var, season) + mv2 = ref_data.get_climo_variable(var, season) # type: ignore except (RuntimeError, IOError): mv2 = mv1 logger.info("Can not process reference data, analyse test data only") @@ -165,7 +167,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901 ) # For variables with a z-axis. - if mv1.getLevel() and mv2.getLevel(): + if mv1.getLevel() and mv2.getLevel(): # type: ignore plev = parameter.plevs logger.info("Selected pressure level: {}".format(plev)) @@ -221,7 +223,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901 ) # For variables without a z-axis. - elif mv1.getLevel() is None and mv2.getLevel() is None: + elif mv1.getLevel() is None and mv2.getLevel() is None: # type: ignore for region in regions: parameter.var_region = region @@ -243,4 +245,4 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901 "Dimensions of the two variables are different. Aborting." ) - return parameter \ No newline at end of file + return parameter diff --git a/e3sm_diags/driver/utils/climo.py b/e3sm_diags/driver/utils/climo.py index b05ba7f236..3a506b11fd 100644 --- a/e3sm_diags/driver/utils/climo.py +++ b/e3sm_diags/driver/utils/climo.py @@ -45,6 +45,10 @@ def climo(var, season): v = var.asma() # Compute climatology + # FIXME: season can be "ANN" and there is no conditional for "ANN", which + # breaks this function. The old Dataset class would call this function + # and drivers would do a try and except that captures Exception (all + # exceptions), which produces a silent error. if season == "ANNUALCYCLE": cycle = [ "01", diff --git a/e3sm_diags/driver/utils/climo_new.py b/e3sm_diags/driver/utils/climo_new.py index ab65b6f6b6..008984dfe8 100644 --- a/e3sm_diags/driver/utils/climo_new.py +++ b/e3sm_diags/driver/utils/climo_new.py @@ -3,14 +3,15 @@ import xarray as xr import xcdat as xc -CDAT_TO_XCDAT_SEASON_FREQ = {"ANNUALCYCLE": "month", "SEASONALCYCLE": "year"} +CLIMO_FREQ = Literal["ANN", "ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"] +CDAT_TO_XCDAT_SEASON_FREQ = { + "ANN": "month", + "ANNUALCYCLE": "month", + "SEASONALCYCLE": "year", +} -def climo( - ds: xr.Dataset, - data_var: str, - season: Literal["ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"], -) -> xr.DataArray: +def climo(data_var: xr.DataArray, freq: CLIMO_FREQ) -> xr.DataArray: """Computes a variable's climatology for the given season. xCDAT's climatology API uses time bounds to redefine time as the midpoint @@ -18,20 +19,33 @@ def climo( Parameters ---------- - data_var : xr.Dataset - The dataset containing the data variable and time bounds. - data_var : str - The name of the data variable to calculate climatology for. - season : Literal["ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"] - The climatology season. + data_var : xr.DataArray + The data variable. + freq : CLIMO_FREQ + The frequency for calculating climatology + + Returns + ------- + xr.DataArray + The variables' climatology """ - var_time = xc.get_dim_coords(data_var, axis="T") - - if season in ["ANNUALCYCLE", "SEASONALCYCLE"]: - freq = CDAT_TO_XCDAT_SEASON_FREQ[season] - ds_climo = ds.temporal.climatology(data_var, freq=freq) + # Open the data variable's dataset to use xCDAT's climatology API, which + # operates on xr.Dataset objects. + filepath = data_var.encoding["source"] + ds = xr.open_dataset(filepath) + dv_key = data_var.name + + if freq in ["ANN", "ANNUALCYCLE", "SEASONALCYCLE"]: + xc_freq = CDAT_TO_XCDAT_SEASON_FREQ[freq] + ds_climo = ds.temporal.climatology(dv_key, freq=xc_freq) else: - ds_climo = ds.temporal.climatology(data_var, freq="season") - ds_climo = ds_climo.sel(f"{var_time.name}.season" == season) + # Get the name of the time dimension and subset to the single season + # before calculating climatology. The general best practice for + # performance is to subset then perform calculations (split-group-apply + # paradigm). + time_dim = xc.get_dim_keys(data_var, axis="T") + ds = ds.isel({f"{time_dim}": (ds[time_dim].dt.season == freq)}) + + ds_climo = ds.temporal.climatology(dv_key, freq="season") - return ds_climo[data_var] + return ds_climo[dv_key] diff --git a/e3sm_diags/driver/utils/dataset_new.py b/e3sm_diags/driver/utils/dataset_new.py index 59c57d3849..81e656ec89 100644 --- a/e3sm_diags/driver/utils/dataset_new.py +++ b/e3sm_diags/driver/utils/dataset_new.py @@ -8,17 +8,15 @@ import glob import os import re -from typing import List, Literal, Union +from typing import List, Union import cdms2 import xarray as xr from e3sm_diags.derivations.acme_new import derived_variables -from e3sm_diags.driver.utils.climo_new import climo +from e3sm_diags.driver.utils.climo_new import CLIMO_FREQ, climo from e3sm_diags.driver.utils.general import adjust_time_from_time_bounds -SEASON = Literal["ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"] - class Dataset: def __init__( @@ -141,7 +139,7 @@ def get_timeseries_variable( return variables[0] if len(variables) == 1 else variables def get_climo_variable( - self, var: str, season: SEASON, extra_vars: List[str] = [], *args, **kwargs + self, var: str, season: CLIMO_FREQ, extra_vars: List[str] = [], *args, **kwargs ) -> Union[xr.DataArray, List[xr.DataArray]]: """Get climatology variables from climatology datasets. @@ -149,11 +147,15 @@ def get_climo_variable( For a given season, get the variable and any extra variables and run the climatology on them. + If the variable is a climatology variable then get it directly + from the dataset. If the variable is a time series variable, get the + variable from the dataset and compute the climatology. + Parameters ---------- var : str The variable name. - season : SEASON + season : CLIMO_FREQ The season for calculation climatology. extra_vars : List[str], optional Extra variables to run, by default []. @@ -182,14 +184,13 @@ def get_climo_variable( if not season: raise RuntimeError("Season is invalid.") - # Get the climatology variable directly from the climatology dataset. if self.is_climo(): if self.ref: filename = self.get_ref_filename_climo(season) elif self.test: filename = self.get_test_filename_climo(season) - variables = self._get_climo_vars(filename) - # Compute the climatology using the variable in the timeseries dataset. + + climo_vars = self._get_climo_vars(filename) elif self.is_timeseries(): if self.ref: data_path = self.parameters.reference_data_path @@ -198,21 +199,15 @@ def get_climo_variable( # FIXME: Bounds are not attached to the DataArray so we must pass # the Dataset instead - ds = xr.open_dataset(data_path) timeseries_vars = self._get_timeseries_vars(data_path, *args, **kwargs) - variables = [climo(ds, var.name, season) for var in timeseries_vars] - + climo_vars = [climo(var, season) for var in timeseries_vars] else: msg = "Error when determining what kind (ref or test) " msg += "of variable to get and where to get it from " msg += "(climo or timeseries files)." raise RuntimeError(msg) - # Needed so we can do: - # v1 = Dataset.get_variable('v1', season) - # and also: - # v1, v2, v3 = Dataset.get_variable('v1', season, extra_vars=['v2', 'v3']) - return variables[0] if len(variables) == 1 else variables + return climo_vars[0] if len(climo_vars) == 1 else climo_vars def get_static_variable(self, static_var, primary_var): # TODO: Refactor this method. @@ -373,13 +368,10 @@ def _find_climo_file(self, path_name, data_name, season): # No file found. return "" - def _get_climo_vars(self, filename, extra_vars_only=False): - # TODO: Refactor this method. - """ - For a given season and climo input data, - get the variable (self.var). + def _get_climo_vars(self, filename: str, extra_vars_only: bool = False): + """For a given season and climo input data, get the variable (self.var). - If self.extra_vars is also defined, get them as well. + If ``self.extra_vars`` is also defined, get them as well. """ vars_to_get = [] if not extra_vars_only: