From df3efe5f92631285bf88938ffa2644c1cc693fdc Mon Sep 17 00:00:00 2001 From: Renu Singh Date: Tue, 28 Jan 2025 13:35:17 +0100 Subject: [PATCH] xarray output for power spec and rank hist metrics (#15) * xarray wrapper creates dataset with var as data vars instead of metric as data vars to support variables with diff dimensions (ie. surface variables only have 1 level) * convert to xarray also creates xarray with vars as data vars * plot reads new format of xarray * remove option to return raw dict confusing and not used rather return metrics in formatted xarray anyways * fix convert metric dict with metadata * plot rankhist like paper * migrate rankhist to output xarray not labeled dict * fix timedelta in era5 metrics * handle extra dimensions in xarray when plotting * plot option color per model * fix plot var titles * fix rankhist check in plot.py * add option to filter files in pred_path for eval * convert to xarray dict reports readable error * document eval commands used for rankhist --- docs/archesweather/eval.md | 32 +++ geoarches/evaluation/eval_multistep.py | 29 +-- geoarches/evaluation/plot.py | 199 ++++++++++++++---- geoarches/metrics/brier_skill_score.py | 4 - geoarches/metrics/ensemble_metrics.py | 4 - geoarches/metrics/label_wrapper.py | 73 +++---- geoarches/metrics/rank_histogram.py | 73 ++++--- geoarches/metrics/spherical_power_spectrum.py | 37 ++-- mkdocs.yaml | 1 + poetry.lock | 25 ++- pyproject.toml | 3 +- tests/metrics/test_label_wrapper.py | 102 ++++++--- tests/metrics/test_rank_histogram.py | 56 ++--- .../metrics/test_spherical_power_spectrum.py | 4 +- 14 files changed, 437 insertions(+), 205 deletions(-) create mode 100644 docs/archesweather/eval.md diff --git a/docs/archesweather/eval.md b/docs/archesweather/eval.md new file mode 100644 index 0000000..59040a1 --- /dev/null +++ b/docs/archesweather/eval.md @@ -0,0 +1,32 @@ +# Evaluate + +Set model run name (used in hydra argument `++name=NAME`): +```sh +NAME=archesweathergen +``` + +## Commands to compute metrics + +### Rank histogram +```sh +python -m geoarches.evaluation.eval_multistep \ +--pred_path evalstore/${NAME}/ \ +--output_dir evalstore/${NAME}_metrics/ \ +--groundtruth_path data/era5_240/full/ \ +--multistep 10 --num_workers 4 \ +--metrics era5_rank_histogram_50_members +``` + +## Commands to plot (WIP) + +### Rank histogram +```sh +python -m geoarches.evaluation.plot --output_dir plots/ +--metric_paths evalstore/${NAME}_metrics/test-multistep=10-era5_rank_histogram_50_members.nc +--model_names ArchesWeatherGen \ +--model_colors red \ +--metrics rankhist \ +--vars Z500:geopotential:level:500 T850:temperature:level:850 Q700:specific_humidity:level:700 U850:u_component_of_wind:level:850 V850:v_component_of_wind:level:850 \ +--rankhist_prediction_timedeltas 1 7 \ +--figsize 10 4 +``` \ No newline at end of file diff --git a/geoarches/evaluation/eval_multistep.py b/geoarches/evaluation/eval_multistep.py index 23c123a..4b20676 100644 --- a/geoarches/evaluation/eval_multistep.py +++ b/geoarches/evaluation/eval_multistep.py @@ -76,6 +76,12 @@ def main(): required=True, help="Directory or file path to find model predictions.", ) + parser.add_argument( + "--pred_filename_filter", + nargs="*", # Accepts 0 or more arguments as a list. + type=str, + help="Substring(s) in filenames under --pred_path to keep files to run inference on.", + ) parser.add_argument( "--groundtruth_path", type=str, @@ -129,8 +135,6 @@ def main(): torch.set_grad_enabled(False) device = "cuda" if torch.cuda.is_available() else "cpu" - print("Reading from predictions path:", args.pred_path) - # Output directory to save evaluation. output_dir = args.output_dir Path(output_dir).mkdir(parents=True, exist_ok=True) @@ -164,10 +168,16 @@ def main(): print(f"Reading {len(ds_test.files)} files from groundtruth path: {args.groundtruth_path}.") # Predictions. + def _pred_filename_filter(filename): + for substring in args.pred_filename_filter: + if substring not in filename: + return False + return True + if not args.eval_clim: ds_pred = era5.Era5Dataset( path=args.pred_path, - filename_filter=(lambda x: True), # Update filename_filter to filter within pred_path. + filename_filter=_pred_filename_filter, # Update filename_filter to filter within pred_path. variables=variables, return_timestamp=True, dimension_indexers=dict( @@ -220,7 +230,6 @@ def __getitem__(self, idx): pressure_levels=[500, 700, 850], lead_time_hours=24 if args.multistep else None, rollout_iterations=args.multistep, - return_raw_dict=True, ).to(device) print(f"Computing: {metrics.keys()}") @@ -256,7 +265,7 @@ def __getitem__(self, idx): metric.update(target.to(device), pred.to(device)) for metric_name, metric in metrics.items(): - raw_dict, labelled_metric_output = metric.compute() + labelled_metric_output = metric.compute() if Path(args.pred_path).is_file(): output_filename = f"{Path(args.pred_path).stem}-{metric_name}" @@ -276,19 +285,15 @@ def __getitem__(self, idx): ds = convert_metric_dict_to_xarray(labelled_dict, extra_dimensions) # Write labeled dict. - labelled_dict["groundtruth_path"] = args.groundtruth_path - labelled_dict["predictions_path"] = args.pred_path + labelled_dict["metadata"] = dict( + groundtruth_path=args.groundtruth_path, predictions_path=args.pred_path + ) torch.save(labelled_dict, Path(output_dir).joinpath(f"{output_filename}.pt")) else: ds = labelled_metric_output # Write xr dataset. ds.to_netcdf(Path(output_dir).joinpath(f"{output_filename}.nc")) - # Write raw score dict. - raw_dict["groundtruth_path"] = args.groundtruth_path - raw_dict["predictions_path"] = args.pred_path - torch.save(raw_dict, Path(output_dir).joinpath(f"{output_filename}-raw.pt")) - if __name__ == "__main__": main() diff --git a/geoarches/evaluation/plot.py b/geoarches/evaluation/plot.py index a6acb04..bb2beb7 100644 --- a/geoarches/evaluation/plot.py +++ b/geoarches/evaluation/plot.py @@ -13,6 +13,7 @@ import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np +import seaborn as sns import torch import xarray as xr @@ -24,19 +25,38 @@ "fcrps": dict(y_label="Fair CRPS"), "spskr": dict(y_label="Spread/skill", horizontal_reference=1), "brierskillscore": dict(y_label="Brier skill score", horizontal_reference=0), - "rankhist": dict(y_label="Frequency", horizontal_reference=1), + "rankhist": dict( + y_label="log(Frequency)", x_label="Rank (normalized)", horizontal_reference=0 + ), } HIGH_QUANTILES = ["99.0%", "99.9%", "99.99%"] LOW_QUANTILES = ["1.0%", "0.1%", "0.01%"] +colors = [ + "blue", + "orange", + "green", + "red", + "purple", + "brown", + "pink", + "grey", + "lightgreen", + "lightblue", +] +COLORS = dict(zip(colors, sns.color_palette())) +COLORS["black"] = [0, 0, 0] + def plot_metric( data_dict: dict[str, xr.Dataset], - vars: list[str], + vars: list[tuple], metric_name: str, y_label: str | None = None, + x_label: str | None = None, horizontal_reference: float | None = None, + plot_kwargs={}, figsize=(10, 4), debug: bool = False, ): @@ -47,10 +67,13 @@ def plot_metric( data_dict: Mapping from model name to xarray dataset holding metrics. Dataset variables are metric Dataset dimensions include `variable` and `prediction_timedelta`. - vars: list of variables to read from dimension `variable` in xarray dataset. + vars: list of variables and optional extra dimensions to read from xarray dataset. + Each elemnent is a tuple (var, {dim_name:dim_value,...}). If no extra dimensions, (var, {}). metric_name: Name of metric to read from dataset variables. y_label: (Optional) to label y axis of the figure. + x_label: (Optional) to label x axis of the figure. horizontal_reference: (Optional) y value for horizontal dashed line on the plots. + plot_kwargs: kwargs to pass into plt.plot() such as color. figsize: Figure size. debug: Whether to print debug statements. """ @@ -61,8 +84,11 @@ def plot_metric( if y_label: fig.supylabel(y_label) + if x_label: + fig.supxlabel(x_label) for i, var in enumerate(vars): + alias, var, extra_dims = var if debug: print(var) @@ -70,37 +96,38 @@ def plot_metric( if len(vars) % 2 == 1: col += row # offset bottom row. ax = fig.add_subplot(gs[row, col : col + 2]) - ax.set_title(var) + ax.set_title(alias) ax.set_xlabel("Lead time (days)") for model, ds in data_dict.items(): if metric_name == "rmse": - scores = ds["mse"].sel(variable=var) + scores = ds[var].sel(metric="mse", **extra_dims) scores = np.sqrt(scores) else: - scores = ds[metric_name].sel(variable=var) + scores = ds[var].sel(metric=metric_name) if debug: print(model, scores) days = ds.prediction_timedelta.dt.days - ax.plot(days, scores, label=model) + ax.plot(days, scores, label=model, **plot_kwargs[model]) ax.grid(True) if horizontal_reference is not None: ax.axhline(y=horizontal_reference, color="gray", linestyle="--", linewidth=1) - plt.tight_layout() # ensure titles and other plot elements don't overlap. plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.5)) def plot_brier_metric( data_dict: dict[str, xr.Dataset], - vars: list[str], + vars: list[tuple], quantile_levels: list[str], # high or low metric_name: str, y_label: str | None = None, + x_label: str | None = None, horizontal_reference: float | None = None, + plot_kwargs={}, figsize=(10, 5), debug: bool = False, ): @@ -110,12 +137,15 @@ def plot_brier_metric( data_dict: Mapping from model name to xarray dataset holding metrics. Dataset variables are metric Dataset dimensions include `variable` and `prediction_timedelta`. - vars: list of variables to read from dimension `variable` in xarray dataset. + vars: list of variables and optional extra dimensions to read from xarray dataset. + Each elemnent is a tuple (var, {dim_name:dim_value,...}). If no extra dimensions, (var, {}). quantile_levels: Whether to plot high quantiles (ie. 99%, 99.9%, 99.99% levels) or low for each variable. Same length as `vars` list arg. metric_name: Name of metric to read from dataset variables. y_label: (Optional) to label y axis of the figure. + x_label: (Optional) to label x axis of the figure. horizontal_reference: (Optional) y value for horizontal dashed line on the plots. + plot_kwargs: kwargs to pass into plt.plot() such as color. figsize: Figure size. debug: Whether to print debug statements. """ @@ -123,27 +153,31 @@ def plot_brier_metric( # Create a grid of subplots: quantile vs. variable. fig, axs = plt.subplots(len(HIGH_QUANTILES), len(vars), figsize=figsize) - # y labels. + # Axis labels. if y_label: fig.supylabel(y_label) + if x_label: + fig.supxlabel(x_label) + for q in range(3): axs[q, 0].set_ylabel(f"{10**-q}% extremes") for model, ds in data_dict.items(): for i, (var, quantiles) in enumerate(zip(vars, quantiles_per_var)): + alias, var, extra_dims = var if debug: print(var, quantiles) - axs[0, i].set_title(f"{var}\nextreme {quantile_levels[i]}") + axs[0, i].set_title(f"{alias}\nextreme {quantile_levels[i]}") axs[2, i].set_xlabel("Lead time (days)") for q, quantile in enumerate(quantiles): - scores = ds[metric_name].sel(variable=var, quantile=quantile) + scores = ds[var].sel(metric=metric_name, quantile=quantile, **extra_dims) if debug: print(model, scores) days = ds.prediction_timedelta.dt.days - axs[q, i].plot(days, scores, label=model) + axs[q, i].plot(days, scores, label=model, **plot_kwargs[model]) axs[q, i].grid(True) if horizontal_reference is not None: @@ -151,17 +185,18 @@ def plot_brier_metric( y=horizontal_reference, color="gray", linestyle="--", linewidth=1 ) - plt.tight_layout() # ensure titles and other plot elements don't overlap. plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.5)) def plot_rankhist( data_dict: dict[str, xr.Dataset], - vars: list[str], + vars: list[tuple], prediction_timedeltas_days: list[int], metric_name: str = "rankhist", y_label: str | None = None, + x_label: str | None = None, horizontal_reference: float | None = None, + plot_kwargs={}, figsize=(10, 5), debug: bool = False, ): @@ -171,38 +206,85 @@ def plot_rankhist( data_dict: Mapping from model name to xarray dataset holding metrics. Dataset variables are metric Dataset dimensions include `variable` and `prediction_timedelta`. - vars: list of variables to read from dimension `variable` in xarray dataset. + vars: list of variables and optional extra dimensions to read from xarray dataset. + Each elemnent is a tuple (var, {dim_name:dim_value,...}). If no extra dimensions, (var, {}). prediction_timedeltas_days: List of lead times (in days) to plot. metric_name: Name of metric to read from dataset variables. y_label: (Optional) to label y axis of the figure. + x_label: (Optional) to label x axis of the figure. horizontal_reference: (Optional) y value for horizontal dashed line on the plots. + plot_kwargs: kwargs to pass into plt.plot() such as color. figsize: Figure size. debug: Whether to print debug statements. """ # Create a grid of subplots: vars vs. lead time. fig, axs = plt.subplots(len(prediction_timedeltas_days), len(vars), figsize=figsize) - # y labels. + # Axis labels. if y_label: fig.supylabel(y_label) + if x_label: + fig.supxlabel(x_label) for i, days in enumerate(prediction_timedeltas_days): - axs[i, 0].set_ylabel(f"{days} days") + if days == 1: + axs[i, 0].set_ylabel(f"{days} day") + else: + axs[i, 0].set_ylabel(f"{days} days") for model, ds in data_dict.items(): + if debug: + print(model) for col, var in enumerate(vars): - axs[0, col].set_title(var) + if debug: + print(var) + alias, var, extra_dims = var + axs[0, col].set_title(alias) for row, days in enumerate(prediction_timedeltas_days): - scores = ds[metric_name].sel( - variable=var, prediction_timedelta=timedelta(days=days) + scores = ds[var].sel( + metric=metric_name, prediction_timedelta=timedelta(days=days), **extra_dims ) - axs[row, col].plot(scores, label=model) + axs[row, col].plot(np.log(scores), label=model, **plot_kwargs[model]) axs[row, col].grid(True) + # Normalize rank. + num_ranks = len(scores) + xticks = np.array([0, num_ranks / 2, num_ranks]) + axs[row, col].set_xticks(xticks, xticks / num_ranks) if horizontal_reference is not None: axs[row, col].axhline( y=horizontal_reference, color="gray", linestyle="--", linewidth=1 ) - plt.tight_layout() # ensure titles and other plot elements don't overlap. - plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.5)) + handles, labels = axs[0, 0].get_legend_handles_labels() + fig.legend(handles, labels, loc="lower center", bbox_to_anchor=(0.5, -0.1)) + + +def parse_vars(vars): + """Parse the --vars argument into a list of tuples with 'var' and dictionary with 'dim_name' to 'dim_value'. + + Example input: temperature:level:500, 2m_temperature + Example output: (temperature, dict(level=500)), (2m_temperature, {}) + """ + if not vars: + return None + + def _cast(dim_value): + if dim_value.isdigit(): + return int(dim_value) + return dim_value + + parsed_vars = [] + for var in vars: + if ":" not in var: + parsed_vars.append((var, var, {})) + continue + var_and_dims = var.split(":") + alias, var_name, dims = var_and_dims[0], var_and_dims[1], var_and_dims[2:] + if len(dims) % 2 != 0: + raise ValueError( + f"--vars list requires elements to be in this format '::::::::<...'." + "where alias is the name for the plot, var is the name of the dict/xarray variable and dim1,dim2 are extra dimensions to select with ds.sel()." + "Example: '--vars Z500:geopotential:level:500 Q700:specific_humidity:level:700 T2m:2m_temperature", ) parser.add_argument( "--figsize", @@ -269,7 +361,7 @@ def main(): nargs="+", # Accepts 1 or more arguments as a list. type=int, default=[1, 3, 5, 10, 15], - help="Used only for plotting `rankhist` metric. For each variable, which lead times to plot.", + help="Used only for plotting `rankhist` metric. For each variable, which lead times to plot. Example: --rankhist_prediction_timedeltas 1 7.", ) parser.add_argument( "--force", action="store_true", help="Force save plots if file already exists." @@ -288,6 +380,9 @@ def main(): assert len(args.metric_paths) == len(args.model_names_for_legend), ( "Len of metric_paths != len of model_names_for_legend." ) + assert len(args.model_colors) == len(args.model_names_for_legend), ( + "Len of model_colors != len of model_names_for_legend." + ) output_dir = args.output_dir Path(output_dir).mkdir(parents=True, exist_ok=True) @@ -305,22 +400,29 @@ def main(): extra_dimensions = ["prediction_timedelta"] if "brier" in metric_path: extra_dimensions = ["quantile", "prediction_timedelta"] - if "rankhist" in metric_path: + if "rankhist" in metric_path or "rank_hist" in metric_path: extra_dimensions = ["bins", "prediction_timedelta"] ds = convert_metric_dict_to_xarray(labeled_dict, extra_dimensions) data[model_name].append(ds) + plot_kwargs = defaultdict(dict) + for model_name, color in zip(args.model_names_for_legend, args.model_colors): + plot_kwargs[model_name] = dict(color=color) + for model, ds_list in data.items(): merged_ds = xr.merge(ds_list) data[model] = merged_ds - vars = ds.variable.values - metrics = list(merged_ds.data_vars) + metrics = ds.metric.values + vars = list(merged_ds.data_vars) if args.save_xr_dataset: save_file = Path(output_dir).joinpath(f"{model}_metrics.nc") + if save_file.exists(): + raise ValueError(f"File {save_file} already exists. Did not save xr dataset.") + merged_ds.to_netcdf(save_file) - vars = args.vars or vars + vars = parse_vars(args.vars or vars) metrics = args.metrics or metrics for metric_name in metrics: if args.debug: @@ -340,6 +442,7 @@ def main(): **kwargs, figsize=args.figsize, debug=args.debug, + plot_kwargs=plot_kwargs, ) elif "rankhist" in metric_name: plot_rankhist( @@ -350,14 +453,28 @@ def main(): **kwargs, figsize=args.figsize, debug=args.debug, + plot_kwargs=plot_kwargs, ) else: - plot_metric(data, vars, metric_name, **kwargs, figsize=args.figsize, debug=args.debug) + plot_metric( + data, + vars, + metric_name, + **kwargs, + figsize=args.figsize, + debug=args.debug, + plot_kwargs=plot_kwargs, + ) + plt.tight_layout() # ensure titles and other plot elements don't overlap. + plt.style.use("seaborn-v0_8-paper") + plt.rcParams["font.family"] = "DejaVu Sans" save_file = Path(output_dir).joinpath(f"{metric_name}.png") if save_file.exists(): if not args.force: - raise ValueError(f"File {save_file} already exists. Did not save plot.") + raise ValueError( + f"File {save_file} already exists. Did not save plot. Use --force to overwrite." + ) plt.savefig(save_file, bbox_inches="tight") diff --git a/geoarches/metrics/brier_skill_score.py b/geoarches/metrics/brier_skill_score.py index 208c0c1..d2218bb 100644 --- a/geoarches/metrics/brier_skill_score.py +++ b/geoarches/metrics/brier_skill_score.py @@ -183,7 +183,6 @@ def __init__( pressure_levels=era5.pressure_levels, lead_time_hours: None | int = None, rollout_iterations: None | int = None, - return_raw_dict: bool = False, save_memory: bool = False, ): """ @@ -202,7 +201,6 @@ def __init__( rollout_iterations: Size of timedelta dimension (number of rollout iterations in multistep predictions). Set to explicitly handle metrics computed on predictions from multistep rollout. See param `lead_time_hours`. - return_raw_dict: Whether to also return the raw output from the metrics. """ # Quantiles for each var across gridpoints and times. with resources.as_file(resources.files(geoarches_stats).joinpath(quantiles_filepath)) as f: @@ -258,7 +256,6 @@ def _add_quantile_index(variable_indices): lead_time_hours=lead_time_hours, rollout_iterations=rollout_iterations, ), - return_raw_dict=return_raw_dict, ) if level_variables: kwargs["level"] = LabelDictWrapper( @@ -275,6 +272,5 @@ def _add_quantile_index(variable_indices): lead_time_hours=lead_time_hours, rollout_iterations=rollout_iterations, ), - return_raw_dict=return_raw_dict, ) super().__init__(**kwargs) diff --git a/geoarches/metrics/ensemble_metrics.py b/geoarches/metrics/ensemble_metrics.py index f29f79c..9782264 100644 --- a/geoarches/metrics/ensemble_metrics.py +++ b/geoarches/metrics/ensemble_metrics.py @@ -178,7 +178,6 @@ def __init__( save_memory: bool = False, lead_time_hours: None | int = None, rollout_iterations: None | int = None, - return_raw_dict: bool = False, ): """ Args: @@ -195,7 +194,6 @@ def __init__( rollout_iterations: Size of timedelta dimension (number of rollout iterations in multistep predictions). Set to explicitly handle metrics computed on predictions from multistep rollout. See param `lead_time_hours`. - return_raw_dict: Whether to also return the raw output from the metrics. """ # Initialize separate metrics for level vars and surface vars. kwargs = {} @@ -209,7 +207,6 @@ def __init__( lead_time_hours=lead_time_hours, rollout_iterations=rollout_iterations, ), - return_raw_dict=return_raw_dict, ) if level_variables: level_ensemble_metric = EnsembleMetrics( @@ -223,6 +220,5 @@ def __init__( lead_time_hours=lead_time_hours, rollout_iterations=rollout_iterations, ), - return_raw_dict=return_raw_dict, ) super().__init__(**kwargs) diff --git a/geoarches/metrics/label_wrapper.py b/geoarches/metrics/label_wrapper.py index 2632b8c..1d4c31e 100644 --- a/geoarches/metrics/label_wrapper.py +++ b/geoarches/metrics/label_wrapper.py @@ -14,7 +14,7 @@ class LabelDictWrapper(Metric): """Wrapper class around metric for extracting metric outputs into a labelled dictionary. Helpful for WandB which needs to log single values. - Expects the wrapped metric to return a dictionary holding computed metrics: + Expects the metric to return a dictionary holding computed metrics: - keys: metric_name - values: torch tensors with shape (..., *(variable_index)) variable_index is passed in with param `variable_indices` @@ -39,14 +39,12 @@ class LabelDictWrapper(Metric): dict mapping metric name to tensors that have shape (..., *(variable_index)). variable_indices: Mapping from variable name to index (ie. var, lev) into tensor holding computed metric. ie. dict(T2m=(2, 0), U10=(0, 0), V10=(1, 0), SP=(3, 0)). - return_raw_dict: Whether to also return the raw output from the metrics (along with the labelled dict). """ def __init__( self, metric: Metric, variable_indices: Dict[str, tuple], - return_raw_dict: bool = False, ): super().__init__() if not isinstance(metric, Metric): @@ -56,8 +54,6 @@ def __init__( self.metric = metric self.variable_indices = variable_indices - self.return_raw_dict = return_raw_dict - def _convert(self, raw_metric_dict: Dict[str, Tensor]): # Label metrics. labeled_dict = dict() @@ -70,11 +66,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: self.metric.update(*args, **kwargs) def compute(self) -> Dict[str, Tensor]: - raw_metric_dict = self.metric.compute() - if self.return_raw_dict: - return raw_metric_dict, self._convert(raw_metric_dict) - else: - return self._convert(raw_metric_dict) + return self._convert(self.metric.compute()) def reset(self) -> None: """Reset metric.""" @@ -123,6 +115,8 @@ def convert_metric_dict_to_xarray( where the separator between dimensions is an underscore. extra_dimensions: list of dimension names, if any extra beyond and . """ + if "metadata" in labeled_dict: + del labeled_dict["metadata"] def _convert_coord(name, value): if "timedelta" in name: @@ -140,36 +134,36 @@ def _convert_coord(name, value): labels = label.split("_") if len(labels) - 2 != len(extra_dimensions): raise ValueError( - f"Expected length of extra_dimensions for key {label} to be: {len(labels) - 2}." + f"Expected length of extra_dimensions for key {label} to be: {len(labels) - 2}. Got extra_dimensions={extra_dimensions}." ) metrics.add(labels[0]) variables.add(labels[1]) for i, dim in enumerate(extra_dimensions): coords[dim].add(labels[i + 2]) - dimension_shape = [len(coord) for coord in (variables, *coords.values())] + dimension_shape = [len(coord) for coord in (metrics, *coords.values())] # Sort coordinates. - variables = sorted(list(variables)) + metrics = sorted(list(metrics)) for k, coord in coords.items(): coords[k] = sorted(list(coord), key=lambda x: _convert_coord(k, x)) # Aggregate data arrays by variable. - dimensions = ["variable"] + extra_dimensions + dimensions = ["metric"] + extra_dimensions data_arrays = {} - for metric in metrics: + for var in variables: data = [] - for dims in itertools.product(variables, *coords.values()): - var, other_dims = dims[0], dims[1:] + for dims in itertools.product(metrics, *coords.values()): + metric, other_dims = dims[0], dims[1:] other_dims = "_" + "_".join(other_dims) if other_dims else "" key = f"{metric}_{var}{other_dims}" data.append(labeled_dict[key]) data = np.array(data).reshape(dimension_shape) - data_arrays[metric] = (dimensions, data) + data_arrays[var] = (dimensions, data) # Prepare coordinates. for k, coord in coords.items(): coords[k] = [_convert_coord(k, x) for x in coord] - coords["variable"] = variables + coords["metric"] = metrics return xr.Dataset(data_vars=data_arrays, coords=coords) @@ -178,62 +172,61 @@ class LabelXarrayWrapper(Metric): """Wrapper class around metric for extracting metric outputs into a labelled xarray. Helpful for easier analysis. - Expects the wrapped metric to return a dictionary holding computed metrics: + Expects the metric to return a dictionary holding computed metrics: - keys: metric name - values: torch tensor with shape (dim1, dim2, ...) - Returns xarray of computed metrics: - - with metric as data variables and dim1, dim2, as coords. + LabelXarrayWrapper returns xarray of computed metrics: + - with variable as data variables and (metric, dim1, dim2) as dimensions. Warning: this class is not compatible with forward(), only use update() and compute(). See https://github.com/Lightning-AI/torchmetrics/issues/987#issuecomment-2419846736. Example: metric = LabelXarrayWrapper(EnsembleMetrics(preprocess=preprocess_fn), - coord_names = ['variable', 'level'], + dims = ['variable', 'level'], coords= (['T2m','U10m'], [500, 750, 800]) targets, preds = torch.tensor(batch, var, lev, lat, lon), torch.tensor(batch, var, lev, lat, lon) metric.update(targets, preds) - xr_dataset = metric.compute() # EnsembleMetrics returns {"mse": torch.tensor(var, lev) } + xr_dataset = metric.compute() # EnsembleMetrics internally returns {"mse": torch.tensor(var, lev) } Args: metric: base metric that should be wrapped. It is assumed that the metric outputs a dict mapping metric name to tensors that have shape (dim1, dim2, ...). - coord_names: Names of the dimensions returned by metric (same order as tensor shape and `coords`). - coords: Values for the dimensions returned by metric (same order as tensor shape and `coord_names`). - return_raw_dict: Whether to also return the raw output from the metrics (along with the labelled dict). + dims: Names of the dimensions returned by metric (same order as tensor shape and `coords`). + coords: Values for the dimensions returned by metric (same order as tensor shape and `dims`). """ def __init__( self, metric: Metric, - coord_names: Sequence[str], + dims: Sequence[str], coords: Sequence[Sequence], - return_raw_dict: bool = False, ): super().__init__() if not isinstance(metric, Metric): raise ValueError( - f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}" + f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}." + ) + if "variable" not in dims: + raise ValueError( + "One dimension needs to be named 'variable'. It will be used as variable in xarray dataset." ) self.metric = metric - self.coord_names = coord_names + self.dims = dims self.coords = coords - self.return_raw_dict = return_raw_dict - def _convert(self, raw_metric_dict: dict[str, torch.tensor]): + def _convert(self, metric_dict: dict[str, torch.tensor]) -> xr.Dataset: ds = xr.Dataset( - data_vars={metric: (self.coord_names, val) for metric, val in raw_metric_dict.items()}, - coords={coord_name: coord for coord_name, coord in zip(self.coord_names, self.coords)}, + data_vars={metric: (self.dims, val) for metric, val in metric_dict.items()}, + coords={dim: coord for dim, coord in zip(self.dims, self.coords)}, ) + ds = ds.to_array(dim="metric") # Stack metrics into new dim. + ds = ds.to_dataset(dim="variable") # Split into separate variables along var dimension. return ds def update(self, *args: Any, **kwargs: Any) -> None: self.metric.update(*args, **kwargs) def compute(self) -> dict[str, torch.tensor]: - raw_metric_dict = self.metric.compute() - if self.return_raw_dict: - return raw_metric_dict, self._convert(raw_metric_dict) - else: - return self._convert(raw_metric_dict) + return self._convert(self.metric.compute()) diff --git a/geoarches/metrics/rank_histogram.py b/geoarches/metrics/rank_histogram.py index ffa478b..9780a3c 100644 --- a/geoarches/metrics/rank_histogram.py +++ b/geoarches/metrics/rank_histogram.py @@ -1,3 +1,6 @@ +from datetime import timedelta +from typing import Callable + import numpy as np import torch from einops import rearrange @@ -5,7 +8,7 @@ from torchmetrics import Metric from geoarches.dataloaders import era5 -from geoarches.metrics.label_wrapper import LabelDictWrapper, add_timedelta_index +from geoarches.metrics.label_wrapper import LabelXarrayWrapper from .metric_base import TensorDictMetricBase @@ -28,13 +31,16 @@ class RankHistogram(Metric): targets: (batch, ..., lat, lon) preds: (batch, nmembers, ..., lat, lon) - Return dictionary of metrics reduced over batch, lat, lon. + Return: + dictionary of metrics reduced over batch, lat, lon. + metric will have shape (..., rank) where rank = n_members + 1. """ def __init__( self, n_members: int, data_shape: tuple = (4, 1), + preprocess: Callable | None = None, ): """ Args: @@ -42,8 +48,10 @@ def __init__( data_shape: Shape of tensor to hold computed metric. e.g. if targets are shape (batch, timedelta, var, lev, lat, lon) then data_shape = (timedelta, var, lev). This class computes metric across batch, lat, lon dimensions. + preprocess: Takes as input targets or predictions and returns processed tensor. """ Metric.__init__(self) + self.preprocess = preprocess self.n_members = n_members self.data_shape = data_shape @@ -70,6 +78,10 @@ def update(self, targets, preds) -> None: n_members = preds.shape[1] + if self.preprocess: + targets = self.preprocess(targets) + preds = self.preprocess(preds) + # Compute ranks of the targets with respect to ensemble predictions. # only works on cpu device = targets.device @@ -88,7 +100,7 @@ def update(self, targets, preds) -> None: # Count frequency of ranks across lat, lon, batch. # (Might have smarter ways at the expense of memory: https://stackoverflow.com/questions/69429586/how-to-get-a-histogram-of-pytorch-tensors-in-batches) - assert self.data_shape == ranks.shape[1:-2] + assert self.data_shape == ranks.shape[1:-2], f"{self.data_shape} != {ranks.shape[1:-2]}" ranks = rearrange(ranks, "b ... lat lon -> (b lat lon) (...)") bins = n_members + 1 num_histograms = ranks.shape[-1] @@ -132,7 +144,6 @@ def __init__( pressure_levels=era5.pressure_levels, lead_time_hours: None | int = None, rollout_iterations: None | int = None, - return_raw_dict: bool = False, ): """ Args: @@ -148,46 +159,46 @@ def __init__( rollout_iterations: Size of timedelta dimension (number of rollout iterations in multistep predictions). Set to explicitly handle metrics computed on predictions from multistep rollout. See param `lead_time_hours`. - return_raw_dict: Whether to also return the raw output from the metrics. """ + ranks = list(range(1, n_members + 2)) + # Whether to include prediction_timdelta dimension. if rollout_iterations: - surface_data_shape = (rollout_iterations, len(surface_variables), 1) + surface_data_shape = (rollout_iterations, len(surface_variables)) level_data_shape = (rollout_iterations, len(level_variables), len(pressure_levels)) + + surface_dims = ["prediction_timedelta", "variable", "rank"] + level_dims = ["prediction_timedelta", "variable", "level", "rank"] + + timedeltas = [ + timedelta(hours=(i + 1) * lead_time_hours) for i in range(rollout_iterations) + ] + surface_coords = [timedeltas, surface_variables, ranks] + level_coords = [timedeltas, level_variables, pressure_levels, ranks] else: - surface_data_shape = (len(surface_variables), 1) + surface_data_shape = (len(surface_variables),) level_data_shape = (len(level_variables), len(pressure_levels)) - # Variable indices include quantile (var, lev) --> (var, lev, histogram bin number). - # Enable LabelDictWrapper to extract metrics properly from RankHistogram output. - def _add_bin_index(variable_indices): - out = {} - for var, var_lev_idx in variable_indices.items(): - for bin_idx in range(n_members + 1): - out[f"{var}_{bin_idx + 1}"] = (*var_lev_idx, bin_idx) - return out + surface_dims = ["variable", "rank"] + level_dims = ["variable", "level", "rank"] + surface_coords = [surface_variables, ranks] + level_coords = [level_variables, pressure_levels, ranks] # Initialize separate metrics for level vars and surface vars. kwargs = {} if surface_variables: - kwargs["surface"] = LabelDictWrapper( - RankHistogram(data_shape=surface_data_shape, n_members=n_members), - variable_indices=add_timedelta_index( - _add_bin_index(era5.get_surface_variable_indices(surface_variables)), - lead_time_hours=lead_time_hours, - rollout_iterations=rollout_iterations, + kwargs["surface"] = LabelXarrayWrapper( + RankHistogram( + data_shape=surface_data_shape, + n_members=n_members, + preprocess=lambda x: x.squeeze(-3), ), - return_raw_dict=return_raw_dict, + dims=surface_dims, + coords=surface_coords, ) if level_variables: - kwargs["level"] = LabelDictWrapper( + kwargs["level"] = LabelXarrayWrapper( RankHistogram(data_shape=level_data_shape, n_members=n_members), - variable_indices=add_timedelta_index( - _add_bin_index( - era5.get_headline_level_variable_indices(pressure_levels, level_variables) - ), - lead_time_hours=lead_time_hours, - rollout_iterations=rollout_iterations, - ), - return_raw_dict=return_raw_dict, + dims=level_dims, + coords=level_coords, ) super().__init__(**kwargs) diff --git a/geoarches/metrics/spherical_power_spectrum.py b/geoarches/metrics/spherical_power_spectrum.py index bc32f36..c2ef102 100644 --- a/geoarches/metrics/spherical_power_spectrum.py +++ b/geoarches/metrics/spherical_power_spectrum.py @@ -24,11 +24,15 @@ class PowerSpectrum(Metric): """ Calculate spherical power spectrum on both targets and preds separately. + Compute power spectrum on each latlon grid. + Averages power spectrum over batch and members. + Accepted tensor shapes: targets: (batch, ..., lat, lon) preds: (batch, nmembers, ..., lat, lon) - Averages over batch and members. + Return: + metric will have shape (..., degree) where degree is the spherical harmonic degree l. """ def __init__( @@ -114,7 +118,6 @@ def __init__( pressure_levels: str = era5.pressure_levels, lead_time_hours: None | int = None, rollout_iterations: None | int = None, - return_raw_dict: bool = False, ): """ Args: @@ -126,19 +129,20 @@ def __init__( lead_time_hours: timedelta (in hours) between prediction times. rollout_iterations: number of multistep rollout for predictions. (ie. lead time of 24 hours for 3 days, lead_time_hours=24, rollout_iterations=3) - return_raw_dict: Whether to also return the raw output from the metrics. """ # Whether to include prediction_timdelta dimension. if rollout_iterations: - surface_coord_names = ["prediction_timedelta", "variable", "degree"] - level_coord_names = ["prediction_timedelta", "variable", "level", "degree"] + surface_dims = ["prediction_timedelta", "variable", "degree"] + level_dims = ["prediction_timedelta", "variable", "level", "degree"] - timedeltas = [timedelta((i + 1) * lead_time_hours) for i in range(rollout_iterations)] + timedeltas = [ + timedelta(hours=(i + 1) * lead_time_hours) for i in range(rollout_iterations) + ] surface_coords = [timedeltas, surface_variables] level_coords = [timedeltas, level_variables, pressure_levels] else: - surface_coord_names = ["variable", "degree"] - level_coord_names = ["variable", "level", "degree"] + surface_dims = ["variable", "degree"] + level_dims = ["variable", "level", "degree"] surface_coords = [surface_variables] level_coords = [level_variables, pressure_levels] @@ -146,16 +150,21 @@ def __init__( kwargs = {} if surface_variables: kwargs["surface"] = LabelXarrayWrapper( - PowerSpectrum(preprocess=lambda x: _remove_south_pole_lat(x.squeeze(-3))), - coord_names=surface_coord_names, + PowerSpectrum( + # Remove level dim and remove south pole latitude. + preprocess=lambda x: _remove_south_pole_lat(x.squeeze(-3)), + compute_target_spectrum=compute_target_spectrum, + ), + dims=surface_dims, coords=surface_coords, - return_raw_dict=return_raw_dict, ) if level_variables: kwargs["level"] = LabelXarrayWrapper( - PowerSpectrum(preprocess=_remove_south_pole_lat), - coord_names=level_coord_names, + PowerSpectrum( + preprocess=_remove_south_pole_lat, + compute_target_spectrum=compute_target_spectrum, + ), + dims=level_dims, coords=level_coords, - return_raw_dict=return_raw_dict, ) super().__init__(**kwargs) diff --git a/mkdocs.yaml b/mkdocs.yaml index 1dd84ca..e208e30 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -48,4 +48,5 @@ nav: - Setup: archesweather/setup.md - Run: archesweather/run.ipynb - Train: archesweather/train.md + - Evaluate: archesweather/eval.md diff --git a/poetry.lock b/poetry.lock index 1d67ca3..27da9a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4555,6 +4555,29 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.16.5)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.0.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "seaborn" +version = "0.13.2" +description = "Statistical data visualization" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version == \"3.11\" or python_version >= \"3.12\"" +files = [ + {file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"}, + {file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"}, +] + +[package.dependencies] +matplotlib = ">=3.4,<3.6.1 || >3.6.1" +numpy = ">=1.20,<1.24.0 || >1.24.0" +pandas = ">=1.2" + +[package.extras] +dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest-cov", "pytest-xdist"] +docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"] +stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"] + [[package]] name = "sentry-sdk" version = "2.20.0" @@ -5484,4 +5507,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.11" -content-hash = "e13ab734903848e55aebf5563f4799704981d45001e295ea73bc55d564403778" +content-hash = "b4138a1a32cb328ad7e1c2fe2811315ccae282e8c357665726c1f0379530f67b" diff --git a/pyproject.toml b/pyproject.toml index 6b14a74..fa70e55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ dependencies = [ "pyshtools (>=4.13.1,<5.0.0)", "fasteners (>=0.19,<0.20)", "timm (>=1.0.13,<2.0.0)", - "dask (>=2024.12.1,<2025.0.0)" + "dask (>=2024.12.1,<2025.0.0)", + "seaborn (>=0.13.2,<0.14.0)", ] diff --git a/tests/metrics/test_label_wrapper.py b/tests/metrics/test_label_wrapper.py index e648f31..c3c5b72 100644 --- a/tests/metrics/test_label_wrapper.py +++ b/tests/metrics/test_label_wrapper.py @@ -35,7 +35,7 @@ def variable_indices(): } -class TestVarLevLabel: +class TestVarLevDims: def test_convert_to_labeled_dict(self, mock_metric, variable_indices): # Test compute method with labeled dict output wrapper = LabelDictWrapper( @@ -51,11 +51,11 @@ def test_convert_to_labeled_dict(self, mock_metric, variable_indices): torch.testing.assert_close(output["mae_var1"], torch.tensor(0.3)) torch.testing.assert_close(output["mae_var2"], torch.tensor(0.8)) - def test_convert_to_xarray(self, mock_metric, variable_indices): + def test_convert_to_xarray(self, mock_metric): # Test compute method with labeled dict output wrapper = LabelXarrayWrapper( metric=mock_metric, - coord_names=["variable", "level"], + dims=["variable", "level"], coords=[ ["var1", "var2"], [1], @@ -69,17 +69,17 @@ def test_convert_to_xarray(self, mock_metric, variable_indices): output, xr.Dataset( data_vars={ - "rmse": xr.DataArray( - data=np.array([[0.5], [1.0]], dtype=np.float32), - dims=("variable", "level"), + "var1": xr.DataArray( + data=np.array([[0.5], [0.3]], dtype=np.float32), + dims=("metric", "level"), ), - "mae": xr.DataArray( - data=np.array([[0.3], [0.8]], dtype=np.float32), - dims=("variable", "level"), + "var2": xr.DataArray( + data=np.array([[1.0], [0.8]], dtype=np.float32), + dims=("metric", "level"), ), }, coords={ - "variable": ["var1", "var2"], + "metric": ["rmse", "mae"], "level": [1], }, ), @@ -108,7 +108,7 @@ def mock_metric_with_timedelta_dimension(): return mock_metric -class TestTimeDeltaLabel: +class TestTimeDeltaDim: def test_convert_to_labeled_dict(self, mock_metric_with_timedelta_dimension, variable_indices): # Test compute method with labeled dict output wrapper = LabelDictWrapper( @@ -147,6 +147,52 @@ def test_convert_to_labeled_dict_with_explicit_timedelta_dimension( torch.testing.assert_close(output["rmse_var2_12h"], torch.tensor(0.7)) torch.testing.assert_close(output["rmse_var2_18h"], torch.tensor(1.1)) + def test_convert_to_xarray(self, mock_metric_with_timedelta_dimension): + # Test compute method with labeled dict output + wrapper = LabelXarrayWrapper( + metric=mock_metric_with_timedelta_dimension, + dims=["prediction_timedelta", "variable", "level"], + coords=[ + [timedelta(hours=6), timedelta(hours=12), timedelta(hours=18)], + ["var1", "var2"], + [1, 2], + ], + ) + + wrapper.update() + output = wrapper.compute() + + xr.testing.assert_equal( + output, + xr.Dataset( + data_vars={ + "var1": xr.DataArray( + data=np.array( + [[[0.1, 0.2], [0.5, 0.6], [0.9, 1.0]]], + dtype=np.float32, + ), + dims=("metric", "prediction_timedelta", "level"), + ), + "var2": xr.DataArray( + data=np.array( + [[[0.3, 0.4], [0.7, 0.8], [1.1, 1.2]]], + dtype=np.float32, + ), + dims=("metric", "prediction_timedelta", "level"), + ), + }, + coords={ + "metric": ["rmse"], + "level": [1, 2], + "prediction_timedelta": [ + timedelta(hours=6), + timedelta(hours=12), + timedelta(hours=18), + ], + }, + ), + ) + def test_convert_metric_dict_to_xarray(): labeled_dict = { @@ -170,17 +216,17 @@ def test_convert_metric_dict_to_xarray(): xr_dataset, xr.Dataset( data_vars={ - "mse": xr.DataArray( - data=np.array([[1, 2], [5, 6]], dtype=np.float32), - dims=("variable", "prediction_timedelta"), + "T2m": xr.DataArray( + data=np.array([[1, 2], [3, 4]], dtype=np.float32), + dims=("metric", "prediction_timedelta"), ), - "var": xr.DataArray( - data=np.array([[3, 4], [7, 8]], dtype=np.float32), - dims=("variable", "prediction_timedelta"), + "U10": xr.DataArray( + data=np.array([[5, 6], [7, 8]], dtype=np.float32), + dims=("metric", "prediction_timedelta"), ), }, coords={ - "variable": ["T2m", "U10"], + "metric": ["mse", "var"], "prediction_timedelta": [ timedelta(hours=24), timedelta(hours=48), @@ -207,13 +253,13 @@ def test_convert_metric_dict_to_xarray_with_bins_dimension(): xr_dataset, xr.Dataset( data_vars={ - "rankhist": xr.DataArray( + "T2m": xr.DataArray( data=np.array([[[1, 2], [3, 4]]], dtype=np.float32), - dims=("variable", "bins", "prediction_timedelta"), + dims=("metric", "bins", "prediction_timedelta"), ), }, coords={ - "variable": ["T2m"], + "metric": ["rankhist"], "prediction_timedelta": [timedelta(hours=24), timedelta(hours=48)], "bins": [1, 2], }, @@ -237,17 +283,17 @@ def test_convert_metric_dict_to_xarray_without_timedelta_dimension(): xr_dataset, xr.Dataset( data_vars={ - "mse": xr.DataArray( - data=np.array([1, 5], dtype=np.float32), - dims=("variable"), + "T2m": xr.DataArray( + data=np.array([1, 3], dtype=np.float32), + dims=("metric"), ), - "var": xr.DataArray( - data=np.array([3, 7], dtype=np.float32), - dims=("variable"), + "U10": xr.DataArray( + data=np.array([5, 7], dtype=np.float32), + dims=("metric"), ), }, coords={ - "variable": ["T2m", "U10"], + "metric": ["mse", "var"], }, ), ) diff --git a/tests/metrics/test_rank_histogram.py b/tests/metrics/test_rank_histogram.py index 8bfc469..c84555f 100644 --- a/tests/metrics/test_rank_histogram.py +++ b/tests/metrics/test_rank_histogram.py @@ -118,7 +118,31 @@ def test_handle_ties(self): class TestEra5RankHistogram: - def test_output_keys(self): + def test_output_dimensions(self): + bs, mem, lev, lat, lon = 2, 3, 3, 121, 240 + metric = Era5RankHistogram( + n_members=mem, + level_variables=["geopotential", "temperature"], + pressure_levels=[500, 700, 850], + ) + preds = { + "surface": torch.randn(bs, mem, 4, 1, lat, lon), + "level": torch.randn(bs, mem, 2, lev, lat, lon), + } + targets = { + "surface": torch.randn(bs, 4, 1, lat, lon), + "level": torch.randn(bs, 2, lev, lat, lon), + } + + metric.update(targets, preds) + metric.update(targets, preds) + + output_xarray = metric.compute() + + for coord in ["metric", "level", "rank"]: + assert coord in output_xarray.coords + + def test_output_dimensions_with_timdelta_dimension(self): bs, mem, timedelta, lev, lat, lon = 2, 3, 5, 3, 121, 240 metric = Era5RankHistogram( n_members=mem, @@ -139,29 +163,7 @@ def test_output_keys(self): metric.update(targets, preds) metric.update(targets, preds) - output = metric.compute() - expected_metric_keys = [ # All expected keys for final metric for one variable. - "rankhist_U10m_1_6h", - "rankhist_U10m_1_12h", - "rankhist_U10m_1_18h", - "rankhist_U10m_1_24h", - "rankhist_U10m_1_30h", - "rankhist_U10m_2_6h", - "rankhist_U10m_2_12h", - "rankhist_U10m_2_18h", - "rankhist_U10m_2_24h", - "rankhist_U10m_2_30h", - "rankhist_U10m_3_6h", - "rankhist_U10m_3_12h", - "rankhist_U10m_3_18h", - "rankhist_U10m_3_24h", - "rankhist_U10m_3_30h", - "rankhist_U10m_4_6h", - "rankhist_U10m_4_12h", - "rankhist_U10m_4_18h", - "rankhist_U10m_4_24h", - "rankhist_U10m_4_30h", - ] - for expected_metric_key in expected_metric_keys: - assert expected_metric_key in output - assert output[expected_metric_key].numel() == 1 + output_xarray = metric.compute() + + for coord in ["metric", "level", "prediction_timedelta", "rank"]: + assert coord in output_xarray.coords diff --git a/tests/metrics/test_spherical_power_spectrum.py b/tests/metrics/test_spherical_power_spectrum.py index c739251..ba30387 100644 --- a/tests/metrics/test_spherical_power_spectrum.py +++ b/tests/metrics/test_spherical_power_spectrum.py @@ -38,7 +38,7 @@ def test_output_dimensions(self): output_xarray = metric.compute() - for coord in ["variable", "level"]: + for coord in ["metric", "level"]: assert coord in output_xarray.coords def test_output_dimensions_with_timdelta_dimension(self): @@ -63,5 +63,5 @@ def test_output_dimensions_with_timdelta_dimension(self): output_xarray = metric.compute() - for coord in ["prediction_timedelta", "variable", "level"]: + for coord in ["prediction_timedelta", "metric", "level"]: assert coord in output_xarray.coords