Skip to content

Commit

Permalink
use LinearRegression.predict() internally (#240)
Browse files Browse the repository at this point in the history
* use LinearRegression.predict()

* remove stray print

* shorter name

* linting

* CHANGELOG

* fix changelog
  • Loading branch information
mathause authored Jan 13, 2023
1 parent 62a6274 commit 84d7e3c
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 28 deletions.
17 changes: 12 additions & 5 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,30 @@ Internal Changes

- Restore compatibility with regionmask v0.9.0 (`#136 <https://github.com/MESMER-group/mesmer/pull/136>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Renamed the ``interpolation`` keyword of ``np.quantile`` to ``method`` changed in
numpy v1.22.0 (`#137 <https://github.com/MESMER-group/mesmer/pull/137>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Make use of :py:class:`mesmer.stats.linear_regression.LinearRegression` in
:py:func:`mesmer.calibrate_mesmer.train_gt_ic_OLSVOLC` (`#145 <https://github.com/MESMER-group/mesmer/pull/145>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- :py:func:`mesmer.calibrate_mesmer.train_gt_ic_OLSVOLC` (`#145 <https://github.com/MESMER-group/mesmer/pull/145>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- :py:func:`mesmer.create_emulations.create_emus_lv_OLS` and :py:func:`mesmer.create_emulations.create_emus_OLS_each_gp_sep`
(`#240 <https://github.com/MESMER-group/mesmer/pull/240>`_).By `Mathias Hauser <https://github.com/mathause>`_.

- Add python 3.10 to list of supported versions (`#162 <https://github.com/MESMER-group/mesmer/pull/162>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Move contents of setup.py to setup.cfg (`#169 <https://github.com/MESMER-group/mesmer/pull/169>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Use pyproject.toml for the build-system and setuptools_scm for the `__version__`
(`#188 <https://github.com/MESMER-group/mesmer/pull/188>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Added additional tests for the calibration step (`#209 <https://github.com/MESMER-group/mesmer/issues/209>`_):

- one scenario (SSP5-8.5) and two ensemble members (`#211 <https://github.com/MESMER-group/mesmer/pull/211>`_)
- two scenarios (SSP1-2.6 and SSP5-8.5) with one and two ensemble members, respectively (`#214 <https://github.com/MESMER-group/mesmer/pull/214>`_)
- Added additional tests for the calibration step (`#209 <https://github.com/MESMER-group/mesmer/issues/209>`_):
- one scenario (SSP5-8.5) and two ensemble members (`#211 <https://github.com/MESMER-group/mesmer/pull/211>`_)
- two scenarios (SSP1-2.6 and SSP5-8.5) with one and two ensemble members, respectively (`#214 <https://github.com/MESMER-group/mesmer/pull/214>`_)

By `Mathias Hauser <https://github.com/mathause>`_.

Expand Down
26 changes: 12 additions & 14 deletions mesmer/create_emulations/create_emus_lt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

import mesmer.stats
from mesmer.create_emulations.utils import _gather_params, _gather_preds
from mesmer.io.save_mesmer_bundle import save_mesmer_data


Expand Down Expand Up @@ -185,21 +187,17 @@ def create_emus_OLS_each_gp_sep(params_lt, preds_lt, scen):
"""

pred_names = list(preds_lt.keys())
nr_ts = len(
preds_lt[pred_names[0]][scen]
) # nr_ts could vary for different scenarios but is the same for all predictors

emus_lt = {}
for targ in params_lt["targs"]:
nr_gps = len(params_lt["intercept"][targ])
emus_lt[targ] = np.zeros([nr_ts, nr_gps])
for gp in np.arange(nr_gps):
pred_vals = [
params_lt["coef_" + pred][targ][gp] * preds_lt[pred][scen]
for pred in params_lt["preds"]
]

emus_lt[targ][:, gp] = sum(pred_vals) + params_lt["intercept"][targ][gp]

params = _gather_params(params_lt, targ, dims="cell")
predictors = _gather_preds(preds_lt, params_lt["preds"], scen, dims="time")

lr = mesmer.stats.linear_regression.LinearRegression()
lr.params = params

prediction = lr.predict(predictors=predictors)

emus_lt[targ] = prediction.values.T

return emus_lt
22 changes: 13 additions & 9 deletions mesmer/create_emulations/create_emus_lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

import mesmer.stats
from mesmer.create_emulations.utils import _gather_params, _gather_preds
from mesmer.io.save_mesmer_bundle import save_mesmer_data
from mesmer.stats.auto_regression import _draw_auto_regression_correlated_np

Expand Down Expand Up @@ -247,14 +249,16 @@ def create_emus_lv_OLS(params_lv, preds_lv):
for scen in scens_OLS:
emus_lv[scen] = {}

preds = _gather_preds(preds_lv, params_lv["preds"], scen, dims=("scen", "time"))

for targ in params_lv["targs"]:
nr_emus_v, nr_ts_emus_v = preds_lv[pred_names[0]][scen].shape
nr_gps = len(params_lv["coef_" + params_lv["preds"][0]][targ])
emus_lv[scen][targ] = np.zeros([nr_emus_v, nr_ts_emus_v, nr_gps])
for run in np.arange(nr_emus_v):
for gp in np.arange(nr_gps):
emus_lv[scen][targ][run, :, gp] = sum(
params_lv["coef_" + pred][targ][gp] * preds_lv[pred][scen][run]
for pred in params_lv["preds"]
)

params = _gather_params(params_lv, targ, dims="gridpoint")

lr = mesmer.stats.linear_regression.LinearRegression()
lr.params = params
prediction = lr.predict(predictors=preds)

emus_lv[scen][targ] = prediction.values

return emus_lv
74 changes: 74 additions & 0 deletions mesmer/create_emulations/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import xarray as xr


def _gather_preds(preds_dict, predictor_names, scen, dims):
"""gather predictors for linear regression from legacy data structures
Parameters
----------
preds_dict : dict
Dictonary containg all predictors.
predictor_names : list of str
List of all predictors to gather from ``preds_dict``.
scen : str
Scenario for which to read the predictors.
dims : str, tuple of str
Name of string for DataArray
Returns
-------
predictors : dict
Dictonary of gathered predictors.
Notes
-----
This function should become obsolete once switching to the newer data structures.
"""
predictors = {}
for pred in predictor_names:
predictors[pred] = xr.DataArray(preds_dict[pred][scen], dims=dims)

return predictors


def _gather_params(params_dict, targ, dims):
"""gather parameters for linear regression from legacy data structures
Parameters
----------
params_dict : dict
Dictonary containg all parameters.
targ : str
Name of target variable for which to read the parameters.
dims : str, tuple of str
Name of string for DataArray
Returns
-------
params : xr.Dataset
Dataset of gathered parameters.
Notes
-----
This function should become obsolete once switching to the newer data structures.
"""

params = {}
for pred in params_dict["preds"]:

params[pred] = xr.DataArray(params_dict[f"coef_{pred}"][targ], dims=dims)

if "intercept" in params_dict:
intercept = xr.DataArray(params_dict["intercept"][targ], dims=dims)
fit_intercept = True
else:
intercept = 0
fit_intercept = False

params["intercept"] = intercept
params["fit_intercept"] = fit_intercept

return xr.Dataset(data_vars=params)

0 comments on commit 84d7e3c

Please sign in to comment.