Skip to content

Commit

Permalink
xarray output for power spec and rank hist metrics (#15)
Browse files Browse the repository at this point in the history
* xarray wrapper creates dataset with var as data vars

instead of metric as data vars
to support variables with diff dimensions (i.e. surface variables only have 1 level)

* convert to xarray also creates xarray with vars as data vars

* plot reads new format of xarray

* remove option to return raw dict

confusing and not used
rather return metrics in formatted xarray anyway

* fix convert metric dict with metadata

* plot rankhist like paper

* migrate rankhist to output xarray not labeled dict

* fix timedelta in era5 metrics

* handle extra dimensions in xarray when plotting

* plot option color per model

* fix plot var titles

* fix rankhist check in plot.py

* add option to filter files in pred_path for eval

* convert to xarray dict reports readable error

* document eval commands used for rankhist
  • Loading branch information
14renus authored Jan 28, 2025
1 parent 8efe10d commit df3efe5
Show file tree
Hide file tree
Showing 14 changed files with 437 additions and 205 deletions.
32 changes: 32 additions & 0 deletions docs/archesweather/eval.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Evaluate

Set model run name (used in hydra argument `++name=NAME`):
```sh
NAME=archesweathergen
```

## Commands to compute metrics

### Rank histogram
```sh
python -m geoarches.evaluation.eval_multistep \
--pred_path evalstore/${NAME}/ \
--output_dir evalstore/${NAME}_metrics/ \
--groundtruth_path data/era5_240/full/ \
--multistep 10 --num_workers 4 \
--metrics era5_rank_histogram_50_members
```

## Commands to plot (WIP)

### Rank histogram
```sh
python -m geoarches.evaluation.plot --output_dir plots/ \
--metric_paths evalstore/${NAME}_metrics/test-multistep=10-era5_rank_histogram_50_members.nc \
--model_names ArchesWeatherGen \
--model_colors red \
--metrics rankhist \
--vars Z500:geopotential:level:500 T850:temperature:level:850 Q700:specific_humidity:level:700 U850:u_component_of_wind:level:850 V850:v_component_of_wind:level:850 \
--rankhist_prediction_timedeltas 1 7 \
--figsize 10 4
```
29 changes: 17 additions & 12 deletions geoarches/evaluation/eval_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def main():
required=True,
help="Directory or file path to find model predictions.",
)
parser.add_argument(
"--pred_filename_filter",
nargs="*", # Accepts 0 or more arguments as a list.
type=str,
help="Substring(s) in filenames under --pred_path to keep files to run inference on.",
)
parser.add_argument(
"--groundtruth_path",
type=str,
Expand Down Expand Up @@ -129,8 +135,6 @@ def main():
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Reading from predictions path:", args.pred_path)

# Output directory to save evaluation.
output_dir = args.output_dir
Path(output_dir).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -164,10 +168,16 @@ def main():
print(f"Reading {len(ds_test.files)} files from groundtruth path: {args.groundtruth_path}.")

# Predictions.
def _pred_filename_filter(filename):
    """Return True if `filename` should be kept for inference.

    A file is kept only when its name contains every substring passed via
    --pred_filename_filter. When the flag is omitted, argparse leaves
    args.pred_filename_filter as None (nargs="*" without a default), so we
    fall back to an empty list and keep all files instead of crashing on
    `for ... in None`.
    """
    for substring in args.pred_filename_filter or []:
        if substring not in filename:
            return False
    return True

if not args.eval_clim:
ds_pred = era5.Era5Dataset(
path=args.pred_path,
filename_filter=(lambda x: True), # Update filename_filter to filter within pred_path.
filename_filter=_pred_filename_filter, # Update filename_filter to filter within pred_path.
variables=variables,
return_timestamp=True,
dimension_indexers=dict(
Expand Down Expand Up @@ -220,7 +230,6 @@ def __getitem__(self, idx):
pressure_levels=[500, 700, 850],
lead_time_hours=24 if args.multistep else None,
rollout_iterations=args.multistep,
return_raw_dict=True,
).to(device)
print(f"Computing: {metrics.keys()}")

Expand Down Expand Up @@ -256,7 +265,7 @@ def __getitem__(self, idx):
metric.update(target.to(device), pred.to(device))

for metric_name, metric in metrics.items():
raw_dict, labelled_metric_output = metric.compute()
labelled_metric_output = metric.compute()

if Path(args.pred_path).is_file():
output_filename = f"{Path(args.pred_path).stem}-{metric_name}"
Expand All @@ -276,19 +285,15 @@ def __getitem__(self, idx):
ds = convert_metric_dict_to_xarray(labelled_dict, extra_dimensions)

# Write labeled dict.
labelled_dict["groundtruth_path"] = args.groundtruth_path
labelled_dict["predictions_path"] = args.pred_path
labelled_dict["metadata"] = dict(
groundtruth_path=args.groundtruth_path, predictions_path=args.pred_path
)
torch.save(labelled_dict, Path(output_dir).joinpath(f"{output_filename}.pt"))
else:
ds = labelled_metric_output
# Write xr dataset.
ds.to_netcdf(Path(output_dir).joinpath(f"{output_filename}.nc"))

# Write raw score dict.
raw_dict["groundtruth_path"] = args.groundtruth_path
raw_dict["predictions_path"] = args.pred_path
torch.save(raw_dict, Path(output_dir).joinpath(f"{output_filename}-raw.pt"))


if __name__ == "__main__":
main()
Loading

0 comments on commit df3efe5

Please sign in to comment.