From 272d7452295d202a917bc3106f2c6eab8c188ad1 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Tue, 10 Oct 2023 15:42:49 -0700 Subject: [PATCH] CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677) Refer to the PR for more information because the changelog is massive. --- auxiliary_tools/cdat_regression_test.py | 171 ++++++++++++++ conda-env/ci.yml | 5 + conda-env/dev.yml | 7 + e3sm_diags/derivations/derivations.py | 6 + e3sm_diags/driver/lat_lon_driver.py | 211 ++++++++++++++++++ e3sm_diags/driver/utils/dataset_xr.py | 30 +++ e3sm_diags/driver/utils/io.py | 10 + e3sm_diags/driver/utils/regrid.py | 65 ++++++ e3sm_diags/metrics/metrics.py | 20 ++ e3sm_diags/parameter/core_parameter.py | 4 + e3sm_diags/plot/lat_lon_plot.py | 15 ++ e3sm_diags/plot/utils.py | 18 ++ .../driver/utils/test_dataset_xr.py | 3 + tests/e3sm_diags/driver/utils/test_regrid.py | 11 + tests/e3sm_diags/metrics/test_metrics.py | 3 + 15 files changed, 579 insertions(+) create mode 100644 auxiliary_tools/cdat_regression_test.py diff --git a/auxiliary_tools/cdat_regression_test.py b/auxiliary_tools/cdat_regression_test.py new file mode 100644 index 0000000000..d2e669db38 --- /dev/null +++ b/auxiliary_tools/cdat_regression_test.py @@ -0,0 +1,171 @@ +# %% +""" +This script checks for regressions between the refactored and `main` branches +of a diagnostic set. + +How it works +------------ + It compares the absolute and relative differences between two sets of + `.json` files in two separate directories, one for the refactored code + and the other for the `main` branch. This script will generate an Excel file + containing: + 1. The raw metrics by each branch for each variable. + 2. The absolute and relative differences of each variable between branches. + 3. The highest relative differences (threshold > 2% difference) + +How to use +----------- + 1. mamba env create -f conda-env/dev.yml -n e3sm_diags_dev_ + 2. mamba activate e3sm_diags_dev_ + 3. 
Update `DEV_PATH` and `PROD_PATH` in `/auxiliary_tools/cdat_regression_test.py`
+    4. python auxiliary_tools/cdat_regression_test.py
+    5. Excel file generated in `/auxiliary_tools`
+
+Tips
+-----------
+Relative differences should be taken into consideration more so than absolute
+differences.
+    - Relative differences show the scale using a percentage unit.
+    - Absolute differences are just raw numbers that don't factor in
+      floating point size (e.g., 100.00 vs. 0.0001), which can be misleading.
+"""
+import glob
+import logging
+import os
+import time
+from typing import List
+
+import pandas as pd
+
+log_format = (
+    "%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s"
+)
+logging.basicConfig(format=log_format, filemode="w", level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# TODO: Update DEV_PATH and PROD_PATH.
+# ------------------------------------------------------------------------------
+DEV_PATH = "/global/cfs/cdirs/e3sm/www/vo13/examples_658/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model"
+PROD_PATH = "/global/cfs/cdirs/e3sm/www/vo13/examples/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model"
+# ------------------------------------------------------------------------------
+
+if not os.path.exists(DEV_PATH):
+    raise ValueError(f"DEV_PATH does not exist ({DEV_PATH})")
+if not os.path.exists(PROD_PATH):
+    raise ValueError(f"PROD_PATH does not exist ({PROD_PATH})")
+
+DEV_GLOB = sorted(glob.glob(os.path.join(DEV_PATH, "*.json")))
+PROD_GLOB = sorted(glob.glob(os.path.join(PROD_PATH, "*.json")))
+
+TIME_STR = time.strftime("%Y%m%d-%H%M%S")
+EXCEL_FILENAME = f"{TIME_STR}-metrics-diffs.xlsx"
+
+
+def get_metrics(filepaths: List[str]) -> pd.DataFrame:
+    """Get the metrics using a glob of `.json` metric files in a directory.
+
+    Parameters
+    ----------
+    filepaths : List[str]
+        The filepaths for metrics `.json` files.
+
+    Returns
+    -------
+    pd.DataFrame
+        The DataFrame containing the metrics for all of the variables in
+        the results directory.
+    """
+    metrics = []
+
+    for filepath in filepaths:
+        df = pd.read_json(filepath)
+
+        filename = os.path.basename(filepath)
+        var_key = filename.split("-")[1]
+
+        # Prepend the variable key as the outer level of a MultiIndex so rows
+        # from different variables stay distinguishable after concatenation.
+        multiindex = pd.MultiIndex.from_product([[var_key], [*df.index]])
+        df = df.set_index(multiindex)
+
+        # Keep the wide (unstacked) layout; column selection below relies on it.
+        metrics.append(df)
+
+    df_final = pd.concat(metrics)
+
+    # Reorder columns and drop "unit" column (string dtype breaks Pandas
+    # arithmetic).
+    df_final = df_final[["test", "ref", "test_regrid", "ref_regrid", "diff", "misc"]]
+
+    return df_final
+
+
+def get_diffs(df_a: pd.DataFrame, df_b: pd.DataFrame) -> pd.DataFrame:
+    """Get the metrics differences between two DataFrames.
+
+    Parameters
+    ----------
+    df_a : pd.DataFrame
+        The first DataFrame representing "actual" results (aka development).
+    df_b : pd.DataFrame
+        The second DataFrame representing "reference" results (aka production).
+
+    Returns
+    -------
+    pd.DataFrame
+        The DataFrame containing absolute and relative differences between
+        the metrics DataFrames.
+    """
+    # Absolute difference: abs(actual - reference)
+    df_diff = abs(df_a - df_b)
+    df_abs = df_diff.add_suffix("_abs")
+
+    # Relative difference: abs(actual - reference) / abs(actual)
+    df_rel = df_diff / abs(df_a)
+    df_rel = df_rel.add_suffix("_rel")
+
+    # Combine both DataFrames
+    df_final = pd.concat([df_abs, df_rel], axis=1, join="outer")
+
+    return df_final
+
+
+# %% Get the metrics DataFrames.
+df_dev = get_metrics(DEV_GLOB)
+df_prod = get_metrics(PROD_GLOB)
+
+# %% Combine metrics DataFrames.
+df_dev_pref = df_dev.add_prefix("dev_") +df_prod_pref = df_prod.add_prefix("prod_") +df_metrics = pd.concat([df_dev_pref, df_prod_pref], axis=1, join="outer") +#%% +# Sort the columns +df_metrics = df_metrics[ + [ + "dev_test", + "prod_test", + "dev_ref", + "prod_ref", + "dev_test_regrid", + "prod_test_regrid", + "dev_ref_regrid", + "prod_ref_regrid", + "dev_diff", + "prod_diff", + "dev_misc", + "prod_misc", + ] +] + +# %% Get differences between metrics. +df_diffs = get_diffs(df_dev, df_prod) + + +#%% +with pd.ExcelWriter(EXCEL_FILENAME) as writer: + df_metrics.to_excel(writer, sheet_name="metrics") + df_diffs.to_excel(writer, sheet_name="metric_diffs") + + +# %% Only get the metrics where the absolute and relative differences are +# greater than a specific threshold (>1%) diff --git a/conda-env/ci.yml b/conda-env/ci.yml index 18035f0989..5fcd13aa17 100644 --- a/conda-env/ci.yml +++ b/conda-env/ci.yml @@ -26,9 +26,14 @@ dependencies: - numpy >=1.23.0 - shapely >=2.0.0,<3.0.0 - xarray >=2023.02.0 +<<<<<<< HEAD - xcdat >=0.6.0 - xesmf >=0.7.0 - xskillscore >=0.0.20 +======= + - xcdat >=0.5.0 + - xesmf >=0.7.0 +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # Testing # ================== - scipy diff --git a/conda-env/dev.yml b/conda-env/dev.yml index 33cd201fdb..fb55d28313 100644 --- a/conda-env/dev.yml +++ b/conda-env/dev.yml @@ -1,6 +1,7 @@ # Conda development environment for testing local source code changes to `e3sm_diags` before merging them to production (`master` branch). 
name: e3sm_diags_dev channels: + - conda-forge/label/xcdat_dev - conda-forge - defaults dependencies: @@ -24,9 +25,15 @@ dependencies: - numpy >=1.23.0 - shapely >=2.0.0,<3.0.0 - xarray >=2023.02.0 +<<<<<<< HEAD - xcdat >=0.6.0 - xesmf >=0.7.0 - xskillscore >=0.0.20 +======= + - xskillscore >=0.0.20 + - xcdat==0.6.0rc1 + - xesmf >=0.7.0 +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # Testing # ======================= - scipy diff --git a/e3sm_diags/derivations/derivations.py b/e3sm_diags/derivations/derivations.py index ef29ae4711..79b0701509 100644 --- a/e3sm_diags/derivations/derivations.py +++ b/e3sm_diags/derivations/derivations.py @@ -8,9 +8,15 @@ For example to derive 'PRECT': 1. In `DERIVED_VARIABLE` there is an entry for 'PRECT'. 2. The netCDF file does not have a 'PRECT' variable, but has the 'PRECC' +<<<<<<< HEAD and 'PRECT' variables. 3. 'PRECC' and 'PRECL' are used to derive `PRECT` by passing the data for these variables to the formula function 'prect()'. +======= + and 'PRECT' variables. + 3. 'PRECC' and 'PRECL' are used to derive `PRECT` by passing the + data for these variables to the formula function 'prect()'. 
+>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) """ from collections import OrderedDict from typing import Callable, Dict, Tuple diff --git a/e3sm_diags/driver/lat_lon_driver.py b/e3sm_diags/driver/lat_lon_driver.py index bea50fb944..37be5a0908 100755 --- a/e3sm_diags/driver/lat_lon_driver.py +++ b/e3sm_diags/driver/lat_lon_driver.py @@ -1,11 +1,22 @@ from __future__ import annotations +<<<<<<< HEAD from typing import TYPE_CHECKING, List, Tuple +======= +import json +import os +from typing import TYPE_CHECKING, Dict, List, Tuple +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) import xarray as xr from e3sm_diags.driver.utils.dataset_xr import Dataset +<<<<<<< HEAD from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots +======= +from e3sm_diags.driver.utils.general import get_output_dir +from e3sm_diags.driver.utils.io import _write_vars_to_netcdf +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) from e3sm_diags.driver.utils.regrid import ( _apply_land_sea_mask, _subset_on_region, @@ -14,13 +25,25 @@ has_z_axis, regrid_z_axis_to_plevs, ) +<<<<<<< HEAD from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import custom_logger from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg, std from e3sm_diags.plot.lat_lon_plot import plot as plot_func +======= +from e3sm_diags.logger import custom_logger +from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg, std +from e3sm_diags.plot.lat_lon_plot import plot +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) logger = custom_logger(__name__) +# The type annotation for the metrics dictionary. The key is the +# type of metrics and the value is a sub-dictionary of metrics (key is metrics +# type and value is float). 
There is also a "unit" key representing the +# units for the variable. +MetricsDict = Dict[str, str | Dict[str, float | None | List[float]]] + if TYPE_CHECKING: from e3sm_diags.parameter.core_parameter import CoreParameter @@ -63,7 +86,12 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter.var_id = var_key for season in seasons: +<<<<<<< HEAD parameter._set_name_yrs_attrs(test_ds, ref_ds, season) +======= + parameter.test_name_yrs = test_ds.get_name_yrs_attr(season) + parameter.ref_name_yrs = ref_ds.get_name_yrs_attr(season) +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # The land sea mask dataset that is used for masking if the region # is either land or sea. This variable is instantiated here to get @@ -71,6 +99,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask(season) ds_test = test_ds.get_climo_dataset(var_key, season) +<<<<<<< HEAD ds_ref = ref_ds.get_ref_climo_dataset(var_key, season, ds_test) # Store the variable's DataArray objects for reuse. @@ -80,6 +109,32 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: is_vars_3d = has_z_axis(dv_test) and has_z_axis(dv_ref) is_dims_diff = has_z_axis(dv_test) != has_z_axis(dv_ref) +======= + + # If the reference climatology dataset cannot be retrieved + # it will be set the to the test climatology dataset which means + # analysis is only performed on the test dataset. + # TODO: This logic was carried over from legacy implementation. It + # can probably be improved on by setting `ds_ref = None` and not + # performing unnecessary operations on `ds_ref` for model-only runs, + # since it is the same as `ds_test``. 
+ try: + ds_ref = ref_ds.get_climo_dataset(var_key, season) + parameter.model_only = False + except (RuntimeError, IOError): + ds_ref = ds_test + parameter.model_only = True + + logger.info("Cannot process reference data, analyzing test data only.") + + # Store the variable's DataArray objects for reuse. + dv_test = ds_test[var_key] + dv_ref = ds_ref[var_key] + + is_vars_3d = has_z_axis(dv_test) and has_z_axis(dv_ref) + is_dims_diff = has_z_axis(dv_test) != has_z_axis(dv_ref) + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) if not is_vars_3d: _run_diags_2d( parameter, @@ -92,6 +147,10 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: ref_name, ) elif is_vars_3d: +<<<<<<< HEAD +======= + # TODO: Test this conditional with 3D variables. +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) _run_diags_3d( parameter, ds_test, @@ -148,8 +207,14 @@ def _run_diags_2d( The reference name. """ for region in regions: +<<<<<<< HEAD parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) +======= + parameter = _set_param_output_attrs( + parameter, var_key, season, region, ref_name, ilev=None + ) +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) ( metrics_dict, ds_test_region, @@ -165,12 +230,20 @@ def _run_diags_2d( ) _save_data_metrics_and_plots( parameter, +<<<<<<< HEAD plot_func, var_key, ds_test_region, ds_ref_region, ds_diff_region, metrics_dict, +======= + var_key, + metrics_dict, + ds_test_region, + ds_ref_region, + ds_diff_region, +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) ) @@ -213,6 +286,7 @@ def _run_diags_3d( plev = parameter.plevs logger.info("Selected pressure level(s): {}".format(plev)) +<<<<<<< HEAD ds_test_rg = regrid_z_axis_to_plevs(ds_test, var_key, parameter.plevs) ds_ref_rg = regrid_z_axis_to_plevs(ds_ref, var_key, parameter.plevs) @@ -220,6 
+294,16 @@ def _run_diags_3d( z_axis_key = get_z_axis(ds_test_rg[var_key]).name ds_test_ilev = ds_test_rg.sel({z_axis_key: ilev}) ds_ref_ilev = ds_ref_rg.sel({z_axis_key: ilev}) +======= + ds_test = regrid_z_axis_to_plevs(ds_test, var_key, parameter.plevs) + ds_ref = regrid_z_axis_to_plevs(ds_ref, var_key, parameter.plevs) + + for ilev, _ in enumerate(plev): + # TODO: Test the subsetting here with 3D variables + z_axis = get_z_axis(ds_test[var_key]) + ds_test_ilev = ds_test.isel({f"{z_axis}": ilev}) + ds_ref_ilev = ds_ref.isel({f"{z_axis}": ilev}) +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) for region in regions: ( @@ -236,6 +320,7 @@ def _run_diags_3d( region, ) +<<<<<<< HEAD parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev) _save_data_metrics_and_plots( parameter, @@ -248,6 +333,63 @@ def _run_diags_3d( ) +======= + parameter = _set_param_output_attrs( + parameter, var_key, season, region, ref_name, ilev + ) + _save_data_metrics_and_plots( + parameter, + var_key, + metrics_dict, + ds_test_region, + ds_ref_region, + ds_diff_region, + ) + + +def _set_param_output_attrs( + parameter: CoreParameter, + var_key: str, + season: str, + region: str, + ref_name: str, + ilev: float | None, +) -> CoreParameter: + """Set the parameter output attributes based on argument values. + + Parameters + ---------- + parameter : CoreParameter + The parameter. + var_key : str + The variable key. + season : str + The season. + region : str + The region. + ref_name : str + The reference name, + ilev : float | None + The pressure level, by default None. This option is only set if the + variable is 3D. + + Returns + ------- + CoreParameter + The parameter with updated output attributes. 
+ """ + if ilev is None: + parameter.output_file = f"{ref_name}-{var_key}-{season}-{region}" + parameter.main_title = f"{var_key} {season} {region}" + else: + ilev_str = str(int(ilev)) + parameter.output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}" + parameter.main_title = f"{var_key} {ilev_str} 'mb' {season} {region}" + + return parameter + + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) def _get_metrics_by_region( parameter: CoreParameter, ds_test: xr.Dataset, @@ -327,6 +469,13 @@ def _get_metrics_by_region( var_key, ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff ) +<<<<<<< HEAD +======= + _save_data_metrics_and_plots( + parameter, var_key, metrics_dict, ds_test, ds_ref, ds_diff + ) + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) return metrics_dict, ds_test, ds_ref, ds_diff @@ -445,3 +594,65 @@ def _create_metrics_dict( } return metrics_dict +<<<<<<< HEAD +======= + + +def _save_data_metrics_and_plots( + parameter: CoreParameter, + var_key: str, + metrics_dict: MetricsDict, + ds_test: xr.Dataset, + ds_ref: xr.Dataset | None, + ds_diff: xr.Dataset | None, +): + """Save data (optional), metrics, and plots. + + Parameters + ---------- + parameter : CoreParameter + The parameter for the diagnostic. + var_key : str + The variable key. + metrics_dict : Metrics + The dictionary containing metrics for the variable. + ds_test : xr.Dataset + The test dataset. + ds_ref : xr.Dataset | None + The optional reference dataset. If the diagnostic is a model-only run, + then it will be None. + ds_diff : xr.Dataset | None + The optional difference dataset. If the diagnostic is a model-only run, + then it will be None. 
+ """ + if parameter.save_netcdf: + _write_vars_to_netcdf( + parameter, + var_key, + ds_test, + ds_ref, + ds_diff, + ) + + filename = os.path.join( + get_output_dir(parameter.current_set, parameter), + parameter.output_file + ".json", + ) + with open(filename, "w") as outfile: + json.dump(metrics_dict, outfile) + + logger.info(f"Metrics saved in {filename}") + + # Set the viewer description to the "long_name" attr of the variable. + parameter.viewer_descr[var_key] = ds_test[var_key].attrs.get( + "long_name", "No long_name attr in test data" + ) + + plot( + ds_test[var_key], + ds_ref[var_key] if ds_ref is not None else None, + ds_diff[var_key] if ds_diff is not None else None, + metrics_dict, + parameter, + ) +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 251e56c493..169d96ecfd 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -16,7 +16,11 @@ import glob import os import re +<<<<<<< HEAD from typing import TYPE_CHECKING, Callable, Dict, Literal, Tuple +======= +from typing import Callable, Dict, Literal, Tuple +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) import xarray as xr import xcdat as xc @@ -29,10 +33,14 @@ from e3sm_diags.driver import LAND_FRAC_KEY, LAND_OCEAN_MASK_PATH, OCEAN_FRAC_KEY from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ, CLIMO_FREQS, climo from e3sm_diags.logger import custom_logger +<<<<<<< HEAD if TYPE_CHECKING: from e3sm_diags.parameter.core_parameter import CoreParameter +======= +from e3sm_diags.parameter.core_parameter import CoreParameter +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) logger = custom_logger(__name__) @@ -64,6 +72,7 @@ def __init__( "Valid options include 'ref' or 'test'." 
) +<<<<<<< HEAD # If the underlying data is a time series, set the `start_yr` and # `end_yr` attrs based on the data type (ref or test). Note, these attrs # are different for the `area_mean_time_series` parameter. @@ -80,6 +89,20 @@ def __init__( elif self.data_type == "test": self.start_yr = self.parameter.test_start_yr # type: ignore self.end_yr = self.parameter.test_end_yr # type: ignore +======= + # Set the `start_yr` and `end_yr` attrs based on the dataset type. + # Note, these attrs are different for the `area_mean_time_series` + # parameter. + if self.parameter.sets[0] in ["area_mean_time_series"]: + self.start_yr = self.parameter.start_yr # type: ignore + self.end_yr = self.parameter.end_yr # type: ignore + elif self.data_type == "ref": + self.start_yr = self.parameter.ref_start_yr # type: ignore + self.end_yr = self.parameter.ref_end_yr # type: ignore + elif self.data_type == "test": + self.start_yr = self.parameter.test_start_yr # type: ignore + self.end_yr = self.parameter.test_end_yr # type: ignore +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # The derived variables defined in E3SM Diags. If the `CoreParameter` # object contains additional user derived variables, they are added @@ -259,6 +282,7 @@ def _get_global_attr_from_climo_dataset( # -------------------------------------------------------------------------- # Climatology related methods # -------------------------------------------------------------------------- +<<<<<<< HEAD def get_ref_climo_dataset( self, var_key: str, season: CLIMO_FREQ, ds_test: xr.Dataset ): @@ -311,6 +335,8 @@ def get_ref_climo_dataset( return ds_ref +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) def get_climo_dataset(self, var: str, season: CLIMO_FREQ) -> xr.Dataset: """Get the dataset containing the climatology variable. 
@@ -1028,9 +1054,13 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset: ds_land_frac = self.get_climo_dataset(LAND_FRAC_KEY, season) # type: ignore ds_ocean_frac = self.get_climo_dataset(OCEAN_FRAC_KEY, season) # type: ignore except IOError as e: +<<<<<<< HEAD logger.info( f"{e}. Using default land sea mask located at `{LAND_OCEAN_MASK_PATH}`." ) +======= + logger.warning(e) +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) ds_mask = xr.open_dataset(LAND_OCEAN_MASK_PATH) ds_mask = self._squeeze_time_dim(ds_mask) diff --git a/e3sm_diags/driver/utils/io.py b/e3sm_diags/driver/utils/io.py index 09e4794da4..54e2237aa0 100644 --- a/e3sm_diags/driver/utils/io.py +++ b/e3sm_diags/driver/utils/io.py @@ -1,6 +1,7 @@ from __future__ import annotations import errno +<<<<<<< HEAD import json import os from typing import Callable @@ -8,12 +9,19 @@ import xarray as xr from e3sm_diags.driver.utils.type_annotations import MetricsDict +======= +import os + +import xarray as xr + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) from e3sm_diags.logger import custom_logger from e3sm_diags.parameter.core_parameter import CoreParameter logger = custom_logger(__name__) +<<<<<<< HEAD def _save_data_metrics_and_plots( parameter: CoreParameter, plot_func: Callable, @@ -77,6 +85,8 @@ def _save_data_metrics_and_plots( ) +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) def _write_vars_to_netcdf( parameter: CoreParameter, var_key, diff --git a/e3sm_diags/driver/utils/regrid.py b/e3sm_diags/driver/utils/regrid.py index e61476e95e..19cbcad12a 100644 --- a/e3sm_diags/driver/utils/regrid.py +++ b/e3sm_diags/driver/utils/regrid.py @@ -359,6 +359,7 @@ def regrid_z_axis_to_plevs( Replaces `e3sm_diags.driver.utils.general.convert_to_pressure_levels`. 
""" ds = dataset.copy() +<<<<<<< HEAD # Make sure that the input dataset has Z axis bounds, which are required for # getting grid positions during vertical regridding. @@ -369,12 +370,21 @@ def regrid_z_axis_to_plevs( z_axis = get_z_axis(ds[var_key]) z_long_name = z_axis.attrs.get("long_name") +======= + z_axis = get_z_axis(ds[var_key]) + z_long_name = z_axis.attrs.get("long_name") + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) if z_long_name is None: raise KeyError( f"The vertical level ({z_axis.name}) for '{var_key}' does " "not have a 'long_name' attribute to determine whether it is hybrid " "or pressure." ) +<<<<<<< HEAD +======= + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) z_long_name = z_long_name.lower() # Hybrid must be the first conditional statement because the long_name attr @@ -391,11 +401,16 @@ def regrid_z_axis_to_plevs( "'pressure', or 'isobaric'." ) +<<<<<<< HEAD # Add bounds for the new, regridded Z axis if the length is greater than 1. # xCDAT does not support adding bounds for singleton coordinates. new_z_axis = get_z_axis(ds_plevs[var_key]) if len(new_z_axis) > 1: ds_plevs = ds_plevs.bounds.add_bounds("Z") +======= + # Add bounds for the new, regridded Z axis. + ds_plevs = ds_plevs.bounds.add_bounds(axis="Z") +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) return ds_plevs @@ -432,12 +447,28 @@ def _hybrid_to_plevs( ----- Replaces `e3sm_diags.driver.utils.general.hybrid_to_plevs`. """ +<<<<<<< HEAD # TODO: mb units are always expected, but we should consider checking # the units to confirm whether or not unit conversion is needed. +======= + # TODO: Do we need to convert the Z axis to mb units if it is in PA? Or + # do we always expect units to be in mb? 
+>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) z_axis, _ = xc.create_axis("lev", plevs, generate_bounds=False) pressure_grid = xc.create_grid(z=z_axis) pressure_coords = _hybrid_to_pressure(ds, var_key) +<<<<<<< HEAD +======= + + # Make sure that the input dataset has Z axis bounds, which are required for + # getting grid positions during vertical regridding. + try: + ds.bounds.get_bounds("Z") + except KeyError: + ds = ds.bounds.add_bounds("Z") + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # Keep the "axis" and "coordinate" attributes for CF mapping. with xr.set_options(keep_attrs=True): result = ds.regridder.vertical( @@ -498,7 +529,11 @@ def _hybrid_to_pressure(ds: xr.Dataset, var_key: str) -> xr.DataArray: "'hyam' and/or 'hybm' to use for reconstructing to pressure data." ) +<<<<<<< HEAD ps = _convert_dataarray_units_to_mb(ps) +======= + ps = _convert_units_to_mb(ps) +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) pressure_coords = hyam * p0 + hybm * ps pressure_coords.attrs["units"] = "mb" @@ -565,13 +600,23 @@ def _pressure_to_plevs( ----- Replaces `e3sm_diags.driver.utils.general.pressure_to_plevs`. """ +<<<<<<< HEAD # Convert pressure coordinates and bounds to mb if it is not already in mb. ds = _convert_dataset_units_to_mb(ds, var_key) +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # Create the output pressure grid to regrid to using the `plevs` array. z_axis, _ = xc.create_axis("lev", plevs, generate_bounds=False) pressure_grid = xc.create_grid(z=z_axis) +<<<<<<< HEAD +======= + # Convert pressure coordinates to mb if it is not already in mb. 
+ lev_key = xc.get_dim_keys(ds[var_key], axis="Z") + ds[lev_key] = _convert_units_to_mb(ds[lev_key]) + +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) # Keep the "axis" and "coordinate" attributes for CF mapping. with xr.set_options(keep_attrs=True): result = ds.regridder.vertical( @@ -584,6 +629,7 @@ def _pressure_to_plevs( return result +<<<<<<< HEAD def _convert_dataset_units_to_mb(ds: xr.Dataset, var_key: str) -> xr.Dataset: """Convert a dataset's Z axis and bounds to mb if they are not in mb. @@ -635,6 +681,12 @@ def _convert_dataarray_units_to_mb(da: xr.DataArray) -> xr.DataArray: Unit conversion formulas: * hPa = mb +======= +def _convert_units_to_mb(da: xr.DataArray) -> xr.DataArray: + """Convert DataArray to mb (millibars) if not in mb. + + Unit conversion formulas: +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) * mb = Pa / 100 * Pa = (mb * 100) @@ -661,19 +713,32 @@ def _convert_dataarray_units_to_mb(da: xr.DataArray) -> xr.DataArray: if units is None: raise ValueError( +<<<<<<< HEAD f"'{da.name}' has no 'units' attribute to determine if data is in'mb', " "'hPa', or 'Pa' units." ) if units == "Pa": +======= + "'{ps.name}' has no 'units' attribute to determine if data is in 'mb' or " + "'Pa' units." 
+ ) + + if units == "mb": + pass + elif units == "Pa": +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) with xr.set_options(keep_attrs=True): da = da / 100.0 da.attrs["units"] = "mb" +<<<<<<< HEAD elif units == "hPa": da.attrs["units"] = "mb" elif units == "mb": pass +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) else: raise ValueError( f"'{da.name}' should be in 'mb' or 'Pa' (which gets converted to 'mb'), " diff --git a/e3sm_diags/metrics/metrics.py b/e3sm_diags/metrics/metrics.py index 68e5a4fc13..44ee278bd7 100644 --- a/e3sm_diags/metrics/metrics.py +++ b/e3sm_diags/metrics/metrics.py @@ -1,6 +1,9 @@ """This module stores functions to calculate metrics using Xarray objects.""" +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) from typing import List import xarray as xr @@ -30,9 +33,13 @@ def get_weights(ds: xr.Dataset): return ds.spatial.get_weights(axis=["X", "Y"]) +<<<<<<< HEAD def spatial_avg( ds: xr.Dataset, var_key: str, as_list: bool = True ) -> List[float] | xr.DataArray: +======= +def spatial_avg(ds: xr.Dataset, var_key: str) -> List[float]: +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) """Compute a variable's weighted spatial average. Parameters @@ -41,6 +48,7 @@ def spatial_avg( The dataset containing the variable. var_key : str The key of the varible. +<<<<<<< HEAD as_list : bool Return the spatial average as a list of floats, by default True. If False, return an xr.DataArray. @@ -48,6 +56,12 @@ def spatial_avg( Returns ------- List[float] | xr.DataArray +======= + + Returns + ------- + List[float] +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) The spatial average of the variable based on the specified axis. 
Raises @@ -62,10 +76,16 @@ def spatial_avg( ds_avg = ds.spatial.average(var_key, axis=AXES, weights="generate") results = ds_avg[var_key] +<<<<<<< HEAD if as_list: return results.data.tolist() return results +======= + results_list = results.data.tolist() + + return results_list +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) def std(ds: xr.Dataset, var_key: str) -> List[float]: diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py index 4e97e0de45..501579a871 100644 --- a/e3sm_diags/parameter/core_parameter.py +++ b/e3sm_diags/parameter/core_parameter.py @@ -12,6 +12,10 @@ logger = custom_logger(__name__) +from e3sm_diags.derivations.derivations import DerivedVariablesMap +from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ +from e3sm_diags.driver.utils.regrid import REGRID_TOOLS + if TYPE_CHECKING: from e3sm_diags.driver.utils.dataset_xr import Dataset diff --git a/e3sm_diags/plot/lat_lon_plot.py b/e3sm_diags/plot/lat_lon_plot.py index a6b8a0da02..146f3affd1 100644 --- a/e3sm_diags/plot/lat_lon_plot.py +++ b/e3sm_diags/plot/lat_lon_plot.py @@ -20,18 +20,28 @@ def plot( +<<<<<<< HEAD parameter: CoreParameter, +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) da_test: xr.DataArray, da_ref: xr.DataArray | None, da_diff: xr.DataArray | None, metrics_dict: MetricsDict, +<<<<<<< HEAD +======= + parameter: CoreParameter, +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) ): """Plot the variable's metrics generated for the lat_lon set. Parameters ---------- +<<<<<<< HEAD parameter : CoreParameter The CoreParameter object containing plot configurations. +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) da_test : xr.DataArray The test data. 
da_ref : xr.DataArray | None @@ -40,6 +50,11 @@ def plot( The difference between ``ds_test_regrid`` and ``ds_ref_regrid``. metrics_dict : Metrics The metrics. +<<<<<<< HEAD +======= + parameter : CoreParameter + The CoreParameter object containing plot configurations. +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) """ fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18) diff --git a/e3sm_diags/plot/utils.py b/e3sm_diags/plot/utils.py index c4057cbbe6..607d390fd5 100644 --- a/e3sm_diags/plot/utils.py +++ b/e3sm_diags/plot/utils.py @@ -101,7 +101,11 @@ def _add_colormap( fig: plt.figure, parameter: CoreParameter, color_map: str, +<<<<<<< HEAD contour_levels: List[float], +======= + contour_levels: List[str], +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) title: Tuple[str | None, str, str], metrics: Tuple[float, ...], ): @@ -123,7 +127,11 @@ def _add_colormap( The CoreParameter object containing plot configurations. color_map : str The colormap styling to use (e.g., "cet_rainbow.rgb"). +<<<<<<< HEAD contour_levels : List[float] +======= + contour_levels : List[str] +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) The map contour levels. title : Tuple[str | None, str, str] A tuple of strings to form the title of the colormap, in the format @@ -427,12 +435,22 @@ def _determine_tick_step(degrees_covered: float) -> int: return 1 +<<<<<<< HEAD def _get_contour_label_format_and_pad(c_levels: List[float]) -> Tuple[str, int]: +======= +def _get_contour_label_format_and_pad( + c_levels: List[str] | List[str | float], +) -> Tuple[str, int]: +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) """Get the label format and padding for each contour level. 
Parameters ---------- +<<<<<<< HEAD c_levels : List[float] +======= + c_levels : List[str] | List[str | float] +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) The contour levels. Returns diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py index 1fdf6de3e3..4663ccde07 100644 --- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py +++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py @@ -205,6 +205,7 @@ def test_property_is_timeseries_returns_false_and_is_climo_returns_true_for_ref( assert ds.is_climo +<<<<<<< HEAD class TestGetReferenceClimoDataset: @pytest.fixture(autouse=True) def setup(self, tmp_path): @@ -380,6 +381,8 @@ def test_returns_test_dataset_as_default_value_if_climo_dataset_not_found(self): assert ds.model_only +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) class TestGetClimoDataset: @pytest.fixture(autouse=True) def setup(self, tmp_path): diff --git a/tests/e3sm_diags/driver/utils/test_regrid.py b/tests/e3sm_diags/driver/utils/test_regrid.py index 870de6a6ab..a75b13b2f5 100644 --- a/tests/e3sm_diags/driver/utils/test_regrid.py +++ b/tests/e3sm_diags/driver/utils/test_regrid.py @@ -226,7 +226,11 @@ def test_regrids_to_first_dataset_with_equal_latitude_points(self, tool): ds_b = generate_lev_dataset("pressure", pressure_vars=False) result_a, result_b = align_grids_to_lower_res( +<<<<<<< HEAD ds_a, ds_b, "so", tool, "conservative" +======= + ds_a, ds_b, "so", "xesmf", "conservative" +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) ) expected_a = ds_a.copy() @@ -465,11 +469,18 @@ def test_regrids_pressure_coordinates_to_pressure_levels_with_Pa_units( attrs={"xcdat_bounds": "True"}, ) +<<<<<<< HEAD # Update mb to Pa so this test can make sure conversions to mb are done. 
ds_pa = ds.copy() with xr.set_options(keep_attrs=True): ds_pa["lev"] = ds_pa.lev * 100 ds_pa["lev_bnds"] = ds_pa.lev_bnds * 100 +======= + # Update from Pa to mb. + ds_pa = ds.copy() + with xr.set_options(keep_attrs=True): + ds_pa["lev"] = ds_pa.lev * 100 +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) ds_pa.lev.attrs["units"] = "Pa" result = regrid_z_axis_to_plevs(ds_pa, "so", self.plevs) diff --git a/tests/e3sm_diags/metrics/test_metrics.py b/tests/e3sm_diags/metrics/test_metrics.py index 141b5e0d67..9de38fc861 100644 --- a/tests/e3sm_diags/metrics/test_metrics.py +++ b/tests/e3sm_diags/metrics/test_metrics.py @@ -75,6 +75,7 @@ def test_returns_spatial_avg_for_x_y(self): np.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) +<<<<<<< HEAD def test_returns_spatial_avg_for_x_y_as_xr_dataarray(self): expected = [1.5, 1.333299, 1.5] result = spatial_avg(self.ds, "ts", as_list=False) @@ -82,6 +83,8 @@ def test_returns_spatial_avg_for_x_y_as_xr_dataarray(self): assert isinstance(result, xr.DataArray) np.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) +======= +>>>>>>> 33dbaf28 (CDAT Migration Phase 2: Refactor core utilities and `lat_lon` set (#677)) class TestStd: @pytest.fixture(autouse=True)