Skip to content

Commit

Permalink
refactor concat_hist_future (#241)
Browse files Browse the repository at this point in the history
* refactor concat_hist_future

* Apply suggestions from code review
  • Loading branch information
mathause authored Jan 19, 2023
1 parent 7a822c9 commit f5b36fb
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 48 deletions.
31 changes: 10 additions & 21 deletions mesmer/create_emulations/create_emus_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
Functions to create global trend emulations with MESMER.
"""

import numpy as np

from mesmer.create_emulations.utils import _concatenate_hist_future
from mesmer.io.save_mesmer_bundle import save_mesmer_data


Expand Down Expand Up @@ -67,33 +67,22 @@ def create_emus_gt(params_gt, preds_gt, cfg, concat_h_f=False, save_emus=True):
pred_names = list(preds_gt.keys())
scenarios_emus = list(preds_gt[pred_names[0]].keys())

scens_out_f = list(map(lambda x: x.replace("h-", ""), scenarios_emus))
# does nothing in case 'h-' not actually included
if "h-" in scenarios_emus[0]:
scenarios_emus = ["hist"] + [scen.replace("h-", "") for scen in scenarios_emus]

if concat_h_f:
scens_out = scenarios_emus
else:
if "h-" in scenarios_emus[0]:
scens_out = ["hist"] + scens_out_f
else:
scens_out = scens_out_f

# initialize global trend emulations dictionary with scenarios as keys
emus_gt = {}

# apply the chosen method
if "LOWESS" in params_gt["method"]:
if concat_h_f:
for scen_out, scen_out_f in zip(scens_out, scens_out_f):
emus_gt[scen_out] = np.concatenate(
[params_gt["hist"], params_gt[scen_out_f]]
)
else:
for scen_out in scens_out:
emus_gt[scen_out] = params_gt[scen_out]
for scen in scenarios_emus:
emus_gt[scen] = params_gt[scen]
else:
raise ValueError("The chosen method is currently not implemented.")

if concat_h_f:
emus_gt = _concatenate_hist_future(emus_gt)
scenarios_emus = list(emus_gt.keys())

# save the global trend emulation if requested
if save_emus:
save_mesmer_data(
Expand All @@ -107,7 +96,7 @@ def create_emus_gt(params_gt, preds_gt, cfg, concat_h_f=False, save_emus=True):
*params_gt["preds"],
params_gt["targ"],
params_gt["esm"],
*scens_out,
*scenarios_emus,
],
)

Expand Down
41 changes: 14 additions & 27 deletions mesmer/create_emulations/create_emus_lt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"""


import numpy as np

import mesmer.stats
from mesmer.create_emulations.utils import _gather_params, _gather_preds
from mesmer.create_emulations.utils import (
_concatenate_hist_future,
_gather_params,
_gather_preds,
)
from mesmer.io.save_mesmer_bundle import save_mesmer_data


Expand Down Expand Up @@ -77,18 +79,8 @@ def create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True):
pred_names = list(preds_lt.keys())
scenarios_emus = list(preds_lt[pred_names[0]].keys())

if concat_h_f:
if scenarios_emus[0] == "hist":
scens_out_f = scenarios_emus[1:]
scens_out = ["h-" + s for s in scens_out_f]
else:
raise ValueError("This combination is not supported.")
else:
if "h-" in scenarios_emus[0]:
scens_out_f = list(map(lambda x: x.replace("h-", ""), scenarios_emus))
scens_out = ["hist"] + scens_out_f
else:
scens_out = scens_out_f = scenarios_emus
if "h-" in scenarios_emus[0]:
scenarios_emus = ["hist"] + [scen.replace("h-", "") for scen in scenarios_emus]

# check if correct predictors
if pred_names != params_lt["preds"]:
Expand All @@ -112,18 +104,13 @@ def create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True):

# create emulations
emus_lt = {}

for scen in scenarios_emus:
emus_lt[scen] = create_emus_method_lt(params_lt, preds_lt, scen)

if concat_h_f:
lt_hist = create_emus_method_lt(params_lt, preds_lt, "hist")
for scen_out, scen_out_f in zip(scens_out, scens_out_f):
lt_scen_f = create_emus_method_lt(params_lt, preds_lt, scen_out_f)
emus_lt[scen_out] = {}
for targ in params_lt["targs"]:
emus_lt[scen_out][targ] = np.concatenate(
[lt_hist[targ], lt_scen_f[targ]]
)
else:
for scen_out in scens_out:
emus_lt[scen_out] = create_emus_method_lt(params_lt, preds_lt, scen_out)
emus_lt = _concatenate_hist_future(emus_lt)
scenarios_emus = list(emus_lt.keys())

# save the local trends emulation if requested
if save_emus:
Expand All @@ -138,7 +125,7 @@ def create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True):
*params_lt["preds"],
*params_lt["targs"],
params_lt["esm"],
*scens_out,
*scenarios_emus,
],
)

Expand Down
48 changes: 48 additions & 0 deletions mesmer/create_emulations/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,54 @@
import numpy as np
import xarray as xr


def _concatenate_hist_future(data):
"""concatenate historical and future data
Parameters
----------
data : dict
Possibly nested dictionary containing arrays to concatenate. The keys of data
must correspond to the scenarios to use. The values can either be numpy arrays
or dicts of numpy arrays.
Returns
-------
concatenated : dict
Possibly nested dictionary with concatenated arrays.
"""

scens_in = list(data.keys())

if "hist" not in scens_in:
raise ValueError("data does not contain 'hist' scenario")

scens_in.remove("hist")
scens_out = [f"h-{scen}" for scen in scens_in]

concatenated = {}

hist = data.pop("hist")

# data is a nested dict
if isinstance(hist, dict):

for scen_out, scen_in in zip(scens_out, scens_in):
concatenated[scen_out] = {}
for targ in data[scen_in].keys():
concatenated[scen_out][targ] = np.concatenate(
[hist[targ], data[scen_in][targ]]
)

# data is not a nested dict
else:

for scen_out, scen_in in zip(scens_out, scens_in):
concatenated[scen_out] = np.concatenate([hist, data[scen_in]])

return concatenated


def _gather_preds(preds_dict, predictor_names, scen, dims):
"""gather predictors for linear regression from legacy data structures
Expand Down

0 comments on commit f5b36fb

Please sign in to comment.