Skip to content

Commit

Permalink
add option to filter files in pred_path for eval
Browse files Browse the repository at this point in the history
  • Loading branch information
14renus committed Jan 28, 2025
1 parent 14ea7bf commit de3914e
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions geoarches/evaluation/eval_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def main():
required=True,
help="Directory or file path to find model predictions.",
)
parser.add_argument(
"--pred_filename_filter",
nargs="*", # Accepts 0 or more arguments as a list.
type=str,
help="Substring(s) in filenames under --pred_path to keep files to run inference on.",
)
parser.add_argument(
"--groundtruth_path",
type=str,
Expand Down Expand Up @@ -129,8 +135,6 @@ def main():
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Reading from predictions path:", args.pred_path)

# Output directory to save evaluation.
output_dir = args.output_dir
Path(output_dir).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -164,10 +168,16 @@ def main():
print(f"Reading {len(ds_test.files)} files from groundtruth path: {args.groundtruth_path}.")

# Predictions.
def _pred_filename_filter(filename):
for substring in args.pred_filename_filter:
if substring not in filename:
return False
return True

if not args.eval_clim:
ds_pred = era5.Era5Dataset(
path=args.pred_path,
filename_filter=(lambda x: True), # Update filename_filter to filter within pred_path.
filename_filter=_pred_filename_filter, # Update filename_filter to filter within pred_path.
variables=variables,
return_timestamp=True,
dimension_indexers=dict(
Expand Down

0 comments on commit de3914e

Please sign in to comment.