diff --git a/geoarches/configs/module/archesweathergen.yaml b/geoarches/configs/module/archesweathergen.yaml index bb2126e..cf13336 100644 --- a/geoarches/configs/module/archesweathergen.yaml +++ b/geoarches/configs/module/archesweathergen.yaml @@ -12,8 +12,8 @@ module: uncond_proba: 0.0 # what to learn - load_deterministic_model: modelstore/jzh-geoaw-m-seed0 # by default - learn_residual: pred + load_deterministic_model: modelstore/jzh-geoaw-m-seed0 # used to load preds if learn_residual="pred" and not loaded by dataloader from storage. + learn_residual: pred # "default" to just learn [next_state-state], "pred" to learn residual of deterministic model. conditional: prev+det # or prev+det for instance # scales diff --git a/geoarches/dataloaders/era5.py b/geoarches/dataloaders/era5.py index 08b1d83..2fb393b 100644 --- a/geoarches/dataloaders/era5.py +++ b/geoarches/dataloaders/era5.py @@ -105,10 +105,10 @@ def __init__( """ Args: path: Single filepath or directory holding files. - domain: Specify data split for the filename filters (eg. train, val, test, testz0012..). + domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..). Used if `filename_filter` is None. - filename_filter: To filter files within `path` based on filename. - By default, filters files based on `domain`. + filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param. + If None, filters files based on `domain`. variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be processed into tensordict. e.g. {surface:[...], level:[...]}. By default uses standard 6 level and 4 surface vars. dimension_indexers: Dict of dimensions to select using Dataset.sel(dimension_indexers). @@ -249,9 +249,9 @@ def __init__( """ Args: path: Single filepath or directory holding files. - domain: Specify data split for the filename filters (eg. train, val, test, testz0012..) 
- filename_filter: To filter files within `path` based on filename. - By default, filters files based on `domain`. + domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..) + filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param. + If None, filters files based on `domain`. lead_time_hours: Time difference between current state and previous and future states. multistep: Number of future states to load. By default, loads next state only (current time + lead_time_hours). load_prev: Whether to load state at previous timestamp (current time - lead_time_hours). diff --git a/geoarches/dataloaders/era5pred.py b/geoarches/dataloaders/era5pred.py index 73e20d7..857f201 100644 --- a/geoarches/dataloaders/era5pred.py +++ b/geoarches/dataloaders/era5pred.py @@ -22,6 +22,20 @@ def __init__( variables=None, **kwargs, ): + """ + Args: + path: Single filepath or directory holding groundtruth files. + domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..). + filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param. + If None, filters files based on `domain`. + lead_time_hours: Time difference between current state and previous and future states. + pred_path: Single filepath or directory holding model prediction files to also load. + load_prev: Whether to load state at previous timestamp (current time - lead_time_hours). + norm_scheme: Normalization scheme to use. Can be None to perform no normalization. + load_hard_neg: Whether to additionally load hard negative example for contrastive learning. + variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be processed into tensordict. + e.g. {surface:[...], level:[...]}. By default uses standard 6 level and 4 surface vars. 
+ """ super().__init__( path=path, domain=domain, @@ -35,7 +49,7 @@ def __init__( self.load_hard_neg = load_hard_neg # self.filename_filter is already init if pred_path is not None: - self.pred_ds = netcdf.NetcdfDataset( + self.pred_ds = netcdf.XarrayDataset( path=pred_path, filename_filter=self.filename_filter, variables=self.variables,