Skip to content

Commit

Permalink
Update new climo() with DataArray parsing
Browse files Browse the repository at this point in the history
- Add FIXME comment in old `climo()` that breaks with "ANN" season
  • Loading branch information
tomvothecoder committed Mar 22, 2023
1 parent 9054cac commit 8882c38
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 52 deletions.
20 changes: 11 additions & 9 deletions e3sm_diags/driver/lat_lon_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,11 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901

# Get land/ocean fraction for masking.
try:
land_frac = test_data.get_climo_variable("LANDFRAC", season)
ocean_frac = test_data.get_climo_variable("OCNFRAC", season)
# FIXME: Capture the exact exceptions, not the general Exception (bad practice).
except Exception:
land_frac = test_data.get_climo_variable("LANDFRAC", season) # type: ignore
ocean_frac = test_data.get_climo_variable("OCNFRAC", season) # type: ignore
except RuntimeError as e:
logger.warning(e)

mask_path = os.path.join(
e3sm_diags.INSTALL_PATH, "acme_ne30_ocean_land_mask.nc"
)
Expand All @@ -144,14 +145,15 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901
land_frac = ds_mask["LANDFRAC"]
ocean_frac = ds_mask["OCNFRAC"]

# TODO: Now we are here.
parameter.model_only = False
for var in variables:
logger.info("Variable: {}".format(var))
parameter.var_id = var

mv1 = test_data.get_climo_variable(var, season)
mv1 = test_data.get_climo_variable(var, season) # type: ignore
try:
mv2 = ref_data.get_climo_variable(var, season)
mv2 = ref_data.get_climo_variable(var, season) # type: ignore
except (RuntimeError, IOError):
mv2 = mv1
logger.info("Can not process reference data, analyse test data only")
Expand All @@ -165,7 +167,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901
)

# For variables with a z-axis.
if mv1.getLevel() and mv2.getLevel():
if mv1.getLevel() and mv2.getLevel(): # type: ignore
plev = parameter.plevs
logger.info("Selected pressure level: {}".format(plev))

Expand Down Expand Up @@ -221,7 +223,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901
)

# For variables without a z-axis.
elif mv1.getLevel() is None and mv2.getLevel() is None:
elif mv1.getLevel() is None and mv2.getLevel() is None: # type: ignore
for region in regions:
parameter.var_region = region

Expand All @@ -243,4 +245,4 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901
"Dimensions of the two variables are different. Aborting."
)

return parameter
return parameter
2 changes: 2 additions & 0 deletions e3sm_diags/driver/utils/climo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def climo(var, season):
v = var.asma()

# Compute climatology
# FIXME: season can be "ANN" and there is no conditional for "ANN", which
# breaks this function.
if season == "ANNUALCYCLE":
cycle = [
"01",
Expand Down
54 changes: 34 additions & 20 deletions e3sm_diags/driver/utils/climo_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,49 @@
import xarray as xr
import xcdat as xc

CDAT_TO_XCDAT_SEASON_FREQ = {"ANNUALCYCLE": "month", "SEASONALCYCLE": "year"}
CLIMO_FREQ = Literal["ANN", "ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"]
CDAT_TO_XCDAT_SEASON_FREQ = {
"ANN": "month",
"ANNUALCYCLE": "month",
"SEASONALCYCLE": "year",
}


def climo(
ds: xr.Dataset,
data_var: str,
season: Literal["ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"],
) -> xr.DataArray:
def climo(data_var: xr.DataArray, freq: CLIMO_FREQ) -> xr.DataArray:
"""Computes a variable's climatology for the given season.
xCDAT's climatology API uses time bounds to redefine time as the midpoint
between bounds values and month lengths for proper weighting.
Parameters
----------
data_var : xr.Dataset
The dataset containing the data variable and time bounds.
data_var : str
The name of the data variable to calculate climatology for.
season : Literal["ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"]
The climatology season.
data_var : xr.DataArray
The data variable.
freq : CLIMO_FREQ
The frequency for calculating climatology
Returns
-------
xr.DataArray
The variables' climatology
"""
var_time = xc.get_dim_coords(data_var, axis="T")

if season in ["ANNUALCYCLE", "SEASONALCYCLE"]:
freq = CDAT_TO_XCDAT_SEASON_FREQ[season]
ds_climo = ds.temporal.climatology(data_var, freq=freq)
# Open the data variable's dataset to use xCDAT's climatology API, which
# operates on xr.Dataset objects.
filepath = data_var.encoding["source"]
ds = xr.open_dataset(filepath)
dv_key = data_var.name

if freq in ["ANN", "ANNUALCYCLE", "SEASONALCYCLE"]:
xc_freq = CDAT_TO_XCDAT_SEASON_FREQ[freq]
ds_climo = ds.temporal.climatology(dv_key, freq=xc_freq)
else:
ds_climo = ds.temporal.climatology(data_var, freq="season")
ds_climo = ds_climo.sel(f"{var_time.name}.season" == season)
# Get the name of the time dimension and subset to the single season
# before calculating climatology. The general best practice for
# performance is to subset then perform calculations (split-group-apply
# paradigm).
time_dim = xc.get_dim_keys(data_var, axis="T")
ds = ds.isel({f"{time_dim}": (ds[time_dim].dt.season == freq)})

ds_climo = ds.temporal.climatology(dv_key, freq="season")

return ds_climo[data_var]
return ds_climo[dv_key]
38 changes: 15 additions & 23 deletions e3sm_diags/driver/utils/dataset_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@
import glob
import os
import re
from typing import List, Literal, Union
from typing import List, Union

import cdms2
import xarray as xr

from e3sm_diags.derivations.acme_new import derived_variables
from e3sm_diags.driver.utils.climo_new import climo
from e3sm_diags.driver.utils.climo_new import CLIMO_FREQ, climo
from e3sm_diags.driver.utils.general import adjust_time_from_time_bounds

SEASON = Literal["ANNUALCYCLE", "SEASONALCYCLE", "DJF", "MAM", "JJA", "SON"]


class Dataset:
def __init__(
Expand Down Expand Up @@ -141,19 +139,23 @@ def get_timeseries_variable(
return variables[0] if len(variables) == 1 else variables

def get_climo_variable(
self, var: str, season: SEASON, extra_vars: List[str] = [], *args, **kwargs
self, var: str, season: CLIMO_FREQ, extra_vars: List[str] = [], *args, **kwargs
) -> Union[xr.DataArray, List[xr.DataArray]]:
"""Get climatology variables from climatology datasets.
These variables can either be from the test data or reference data.
For a given season, get the variable and any extra variables and run
the climatology on them.
If the variable is a climatology variable then get it directly
from the dataset. If the variable is a time series variable, get the
variable from the dataset and compute the climatology.
Parameters
----------
var : str
The variable name.
season : SEASON
season : CLIMO_FREQ
The season for calculation climatology.
extra_vars : List[str], optional
Extra variables to run, by default [].
Expand Down Expand Up @@ -182,14 +184,13 @@ def get_climo_variable(
if not season:
raise RuntimeError("Season is invalid.")

# Get the climatology variable directly from the climatology dataset.
if self.is_climo():
if self.ref:
filename = self.get_ref_filename_climo(season)
elif self.test:
filename = self.get_test_filename_climo(season)
variables = self._get_climo_vars(filename)
# Compute the climatology using the variable in the timeseries dataset.

climo_vars = self._get_climo_vars(filename)
elif self.is_timeseries():
if self.ref:
data_path = self.parameters.reference_data_path
Expand All @@ -198,21 +199,15 @@ def get_climo_variable(

# FIXME: Bounds are not attached to the DataArray so we must pass
# the Dataset instead
ds = xr.open_dataset(data_path)
timeseries_vars = self._get_timeseries_vars(data_path, *args, **kwargs)
variables = [climo(ds, var.name, season) for var in timeseries_vars]

climo_vars = [climo(var, season) for var in timeseries_vars]
else:
msg = "Error when determining what kind (ref or test) "
msg += "of variable to get and where to get it from "
msg += "(climo or timeseries files)."
raise RuntimeError(msg)

# Needed so we can do:
# v1 = Dataset.get_variable('v1', season)
# and also:
# v1, v2, v3 = Dataset.get_variable('v1', season, extra_vars=['v2', 'v3'])
return variables[0] if len(variables) == 1 else variables
return climo_vars[0] if len(climo_vars) == 1 else climo_vars

def get_static_variable(self, static_var, primary_var):
# TODO: Refactor this method.
Expand Down Expand Up @@ -373,13 +368,10 @@ def _find_climo_file(self, path_name, data_name, season):
# No file found.
return ""

def _get_climo_vars(self, filename, extra_vars_only=False):
# TODO: Refactor this method.
"""
For a given season and climo input data,
get the variable (self.var).
def _get_climo_vars(self, filename: str, extra_vars_only: bool = False):
"""For a given season and climo input data, get the variable (self.var).
If self.extra_vars is also defined, get them as well.
If ``self.extra_vars`` is also defined, get them as well.
"""
vars_to_get = []
if not extra_vars_only:
Expand Down

0 comments on commit 8882c38

Please sign in to comment.