Skip to content

Commit

Permalink
lin reg: add exclude to predict
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Dec 19, 2023
1 parent 616c8f2 commit f31247f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mesmer/stats/_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import xarray as xr

from mesmer.core.utils import _check_dataarray_form, _check_dataset_form
from mesmer.core.utils import _check_dataarray_form, _check_dataset_form, _to_set


class LinearRegression:
Expand Down Expand Up @@ -53,6 +53,7 @@ def fit(
def predict(
self,
predictors: Mapping[str, xr.DataArray],
exclude=None,
):
"""
Predict using the linear model.
Expand All @@ -61,6 +62,8 @@ def predict(
----------
predictors : dict of xr.DataArray
A dict of DataArray objects used as predictors. Must be 1D and contain `dim`.
exclude : str or set of str, default: None
Set of variables to exclude in the prediction.
Returns
-------
Expand All @@ -70,8 +73,10 @@ def predict(

params = self.params

exclude = _to_set(exclude)

non_predictor_vars = {"intercept", "weights", "fit_intercept"}
required_predictors = set(params.data_vars) - non_predictor_vars
required_predictors = set(params.data_vars) - non_predictor_vars - exclude
available_predictors = set(predictors.keys())

if required_predictors != available_predictors:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,44 @@ def test_LR_predict(as_2D):
xr.testing.assert_equal(result, expected)


@pytest.mark.parametrize("as_2D", [True, False])
def test_lr_predict_exclude(as_2D):
    """Check that ``predict`` accepts ``exclude`` given as a str or as a set."""

    def _maybe_squeeze(obj):
        # In the non-2D variant the length-1 "x" dimension is squeezed away.
        return obj if as_2D else obj.squeeze()

    lr = mesmer.stats.LinearRegression()
    lr.params = _maybe_squeeze(
        xr.Dataset(
            data_vars={
                "intercept": ("x", [5]),
                "fit_intercept": True,
                "tas": ("x", [3]),
                "tas2": ("x", [1]),
            }
        )
    )

    tas = xr.DataArray([0, 1, 2], dims="time")

    # Without ``exclude`` the missing "tas2" predictor must be rejected.
    with pytest.raises(ValueError, match="Missing or superfluous predictors"):
        lr.predict({"tas": tas})

    # Excluding "tas2" leaves intercept + 3 * tas; the str and the set
    # spelling of ``exclude`` must give the same result.
    for exclude in ("tas2", {"tas2"}):
        result = lr.predict({"tas": tas}, exclude=exclude)
        expected = _maybe_squeeze(xr.DataArray([[5, 8, 11]], dims=("x", "time")))

        xr.testing.assert_equal(result, expected)

    # Excluding every predictor leaves only the intercept.
    result = lr.predict({}, exclude={"tas", "tas2"})
    expected = _maybe_squeeze(xr.DataArray([5], dims="x"))

    xr.testing.assert_equal(result, expected)


@pytest.mark.parametrize("as_2D", [True, False])
def test_LR_residuals(as_2D):

Expand Down

0 comments on commit f31247f

Please sign in to comment.