diff --git a/mesmer/create_emulations/create_emus_gt.py b/mesmer/create_emulations/create_emus_gt.py index cfec21d0..d46cb3d4 100644 --- a/mesmer/create_emulations/create_emus_gt.py +++ b/mesmer/create_emulations/create_emus_gt.py @@ -6,8 +6,8 @@ Functions to create global trend emulations with MESMER. """ -import numpy as np +from mesmer.create_emulations.utils import _concatenate_hist_future from mesmer.io.save_mesmer_bundle import save_mesmer_data @@ -67,33 +67,22 @@ def create_emus_gt(params_gt, preds_gt, cfg, concat_h_f=False, save_emus=True): pred_names = list(preds_gt.keys()) scenarios_emus = list(preds_gt[pred_names[0]].keys()) - scens_out_f = list(map(lambda x: x.replace("h-", ""), scenarios_emus)) - # does nothing in case 'h-' not actually included + if "h-" in scenarios_emus[0]: + scenarios_emus = ["hist"] + [scen.replace("h-", "") for scen in scenarios_emus] - if concat_h_f: - scens_out = scenarios_emus - else: - if "h-" in scenarios_emus[0]: - scens_out = ["hist"] + scens_out_f - else: - scens_out = scens_out_f - - # initialize global trend emulations dictionary with scenarios as keys emus_gt = {} # apply the chosen method if "LOWESS" in params_gt["method"]: - if concat_h_f: - for scen_out, scen_out_f in zip(scens_out, scens_out_f): - emus_gt[scen_out] = np.concatenate( - [params_gt["hist"], params_gt[scen_out_f]] - ) - else: - for scen_out in scens_out: - emus_gt[scen_out] = params_gt[scen_out] + for scen in scenarios_emus: + emus_gt[scen] = params_gt[scen] else: raise ValueError("The chosen method is currently not implemented.") + if concat_h_f: + emus_gt = _concatenate_hist_future(emus_gt) + scenarios_emus = list(emus_gt.keys()) + # save the global trend emulation if requested if save_emus: save_mesmer_data( @@ -107,7 +96,7 @@ def create_emus_gt(params_gt, preds_gt, cfg, concat_h_f=False, save_emus=True): *params_gt["preds"], params_gt["targ"], params_gt["esm"], - *scens_out, + *scenarios_emus, ], ) diff --git a/mesmer/create_emulations/create_emus_lt.py b/mesmer/create_emulations/create_emus_lt.py index adccf183..dffa82ff 100644 --- a/mesmer/create_emulations/create_emus_lt.py +++ b/mesmer/create_emulations/create_emus_lt.py @@ -7,10 +7,12 @@ """ -import numpy as np - import mesmer.stats -from mesmer.create_emulations.utils import _gather_params, _gather_preds +from mesmer.create_emulations.utils import ( + _concatenate_hist_future, + _gather_params, + _gather_preds, +) from mesmer.io.save_mesmer_bundle import save_mesmer_data @@ -77,18 +79,8 @@ def create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True): pred_names = list(preds_lt.keys()) scenarios_emus = list(preds_lt[pred_names[0]].keys()) - if concat_h_f: - if scenarios_emus[0] == "hist": - scens_out_f = scenarios_emus[1:] - scens_out = ["h-" + s for s in scens_out_f] - else: - raise ValueError("This combination is not supported.") - else: - if "h-" in scenarios_emus[0]: - scens_out_f = list(map(lambda x: x.replace("h-", ""), scenarios_emus)) - scens_out = ["hist"] + scens_out_f - else: - scens_out = scens_out_f = scenarios_emus + if "h-" in scenarios_emus[0]: + scenarios_emus = ["hist"] + [scen.replace("h-", "") for scen in scenarios_emus] # check if correct predictors if pred_names != params_lt["preds"]: @@ -112,18 +104,13 @@ def create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True): # create emulations emus_lt = {} + + for scen in scenarios_emus: + emus_lt[scen] = create_emus_method_lt(params_lt, preds_lt, scen) + if concat_h_f: - lt_hist = create_emus_method_lt(params_lt, preds_lt, "hist") - for scen_out, scen_out_f in zip(scens_out, scens_out_f): - lt_scen_f = create_emus_method_lt(params_lt, preds_lt, scen_out_f) - emus_lt[scen_out] = {} - for targ in params_lt["targs"]: - emus_lt[scen_out][targ] = np.concatenate( - [lt_hist[targ], lt_scen_f[targ]] - ) - else: - for scen_out in scens_out: - emus_lt[scen_out] = create_emus_method_lt(params_lt, preds_lt, scen_out) + emus_lt = _concatenate_hist_future(emus_lt) + scenarios_emus = list(emus_lt.keys()) # save the local trends emulation if requested if save_emus: @@ -138,7 +125,7 @@ def create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True): *params_lt["preds"], *params_lt["targs"], params_lt["esm"], - *scens_out, + *scenarios_emus, ], ) diff --git a/mesmer/create_emulations/utils.py b/mesmer/create_emulations/utils.py index 530060a4..920293c7 100644 --- a/mesmer/create_emulations/utils.py +++ b/mesmer/create_emulations/utils.py @@ -1,6 +1,54 @@ +import numpy as np import xarray as xr +def _concatenate_hist_future(data): + """concatenate historical and future data + + Parameters + ---------- + data : dict + Possibly nested dictionary containing arrays to concatenate. The keys of data + must correspond to the scenarios to use. The values can either be numpy arrays + or dicts of numpy arrays. + + Returns + ------- + concatenated : dict + Possibly nested dictionary with concatenated arrays. + """ + + scens_in = list(data.keys()) + + if "hist" not in scens_in: + raise ValueError("data does not contain 'hist' scenario") + + scens_in.remove("hist") + scens_out = [f"h-{scen}" for scen in scens_in] + + concatenated = {} + + hist = data.pop("hist") + + # data is a nested dict + if isinstance(hist, dict): + + for scen_out, scen_in in zip(scens_out, scens_in): + concatenated[scen_out] = {} + for targ in data[scen_in].keys(): + concatenated[scen_out][targ] = np.concatenate( + [hist[targ], data[scen_in][targ]] + ) + + # data is not a nested dict + else: + + for scen_out, scen_in in zip(scens_out, scens_in): + concatenated[scen_out] = np.concatenate([hist, data[scen_in]]) + + return concatenated + + def _gather_preds(preds_dict, predictor_names, scen, dims): """gather predictors for linear regression from legacy data structures