Skip to content

Commit

Permalink
Documment functionality to load predictions for residual training
Browse files Browse the repository at this point in the history
  • Loading branch information
14renus committed Jan 31, 2025
1 parent df3efe5 commit 66cd870
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions geoarches/configs/module/archesweathergen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ module:
uncond_proba: 0.0

# what to learn
load_deterministic_model: modelstore/jzh-geoaw-m-seed0 # by default
learn_residual: pred
load_deterministic_model: modelstore/jzh-geoaw-m-seed0 # used to load preds if learn_residual="pred" and not loaded by dataloader from storage.
learn_residual: pred # "default" to just learn [next_state-state], "pred" to learn residual of deterministic model.
conditional: prev+det # or prev+det for instance

# scales
Expand Down
12 changes: 6 additions & 6 deletions geoarches/dataloaders/era5.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def __init__(
"""
Args:
path: Single filepath or directory holding files.
domain: Specify data split for the filename filters (eg. train, val, test, testz0012..).
domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..).
Used if `filename_filter` is None.
filename_filter: To filter files within `path` based on filename.
By default, filters files based on `domain`.
filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param.
If None, filters files based on `domain`.
variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be processed into tensordict.
e.g. {surface:[...], level:[...]}. By default uses standard 6 level and 4 surface vars.
dimension_indexers: Dict of dimensions to select using Dataset.sel(dimension_indexers).
Expand Down Expand Up @@ -249,9 +249,9 @@ def __init__(
"""
Args:
path: Single filepath or directory holding files.
domain: Specify data split for the filename filters (eg. train, val, test, testz0012..)
filename_filter: To filter files within `path` based on filename.
By default, filters files based on `domain`.
domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..)
filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param.
If None, filters files based on `domain`.
lead_time_hours: Time difference between current state and previous and future states.
multistep: Number of future states to load. By default, loads next state only (current time + lead_time_hours).
load_prev: Whether to load state at previous timestamp (current time - lead_time_hours).
Expand Down
16 changes: 15 additions & 1 deletion geoarches/dataloaders/era5pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ def __init__(
variables=None,
**kwargs,
):
"""
Args:
path: Single filepath or directory holding groundtruth files.
domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..).
filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param.
If None, filters files based on `domain`.
lead_time_hours: Time difference between current state and previous and future states.
pred_path: Single filepath or directory holding model prediction files to also load.
load_prev: Whether to load state at previous timestamp (current time - lead_time_hours).
norm_scheme: Normalization scheme to use. Can be None to perform no normalization.
load_hard_neg: Whether to additionallty load hard negative example for contrastive learning.
variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be processed into tensordict.
e.g. {surface:[...], level:[...] By default uses standard 6 level and 4 surface vars.
"""
super().__init__(
path=path,
domain=domain,
Expand All @@ -35,7 +49,7 @@ def __init__(
self.load_hard_neg = load_hard_neg
# self.filename_filter is already init
if pred_path is not None:
self.pred_ds = netcdf.NetcdfDataset(
self.pred_ds = netcdf.XarrayDataset(
path=pred_path,
filename_filter=self.filename_filter,
variables=self.variables,
Expand Down

0 comments on commit 66cd870

Please sign in to comment.