From b58643919ecb6970e75c08afcce882af27b9e22f Mon Sep 17 00:00:00 2001 From: Zhishi Wang Date: Mon, 9 Nov 2020 14:55:33 -0800 Subject: [PATCH] Predict fix (#281) * fix regressor matrix extraction * unit test for cases with mixed pos® regressors --- orbit/models/dlt.py | 2 +- orbit/models/lgt.py | 4 +++- tests/orbit/models/test_dlt.py | 29 +++++++++++++++++++++++++++++ tests/orbit/models/test_lgt.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 2 deletions(-) diff --git a/orbit/models/dlt.py b/orbit/models/dlt.py index 89e630ff..0213f199 100644 --- a/orbit/models/dlt.py +++ b/orbit/models/dlt.py @@ -193,7 +193,7 @@ def _predict(self, posterior_estimates, df=None, include_error=False, decompose= # calculate regression component if self.regressor_col is not None and len(self.regressor_col) > 0: regressor_beta = regressor_beta.t() - regressor_matrix = df[self.regressor_col].values + regressor_matrix = df[self._regressor_col].values regressor_torch = torch.from_numpy(regressor_matrix).double() regressor_component = torch.matmul(regressor_torch, regressor_beta) regressor_component = regressor_component.t() diff --git a/orbit/models/lgt.py b/orbit/models/lgt.py index a25d8939..7f40dbb1 100644 --- a/orbit/models/lgt.py +++ b/orbit/models/lgt.py @@ -224,6 +224,8 @@ def _set_static_regression_attributes(self): self._regular_regressor_beta_prior.append(self._regressor_beta_prior[index]) self._regular_regressor_sigma_prior.append(self._regressor_sigma_prior[index]) + self._regressor_col = self._positive_regressor_col + self._regular_regressor_col + def _set_with_mcmc(self): estimator_type = self.estimator_type # set `_with_mcmc` attribute based on estimator type @@ -521,7 +523,7 @@ def _predict(self, posterior_estimates, df, include_error=False, decompose=False # calculate regression component if self.regressor_col is not None and len(self.regressor_col) > 0: regressor_beta = regressor_beta.t() - regressor_matrix = df[self.regressor_col].values + regressor_matrix = df[self._regressor_col].values regressor_torch = torch.from_numpy(regressor_matrix).double() regressor_component = torch.matmul(regressor_torch, regressor_beta) regressor_component = regressor_component.t() diff --git a/tests/orbit/models/test_dlt.py b/tests/orbit/models/test_dlt.py index 53e21010..b1322c4a 100644 --- a/tests/orbit/models/test_dlt.py +++ b/tests/orbit/models/test_dlt.py @@ -1,4 +1,6 @@ import pytest +import numpy as np + from orbit.models.dlt import BaseDLT, DLTFull, DLTAggregated, DLTMAP from orbit.estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP @@ -251,3 +253,30 @@ def test_dlt_predict_all_positive_reg(iclaims_training_data): predicted_df = dlt.predict(df, decompose=True) assert any(predicted_df['regression'].values) + +def test_dlt_predict_mixed_regular_positive(iclaims_training_data): + df = iclaims_training_data + + dlt = DLTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.filling', 'trend.job'], + regressor_sign=['=', '+', '='], + seasonality=52, + seed=8888, + ) + dlt.fit(df) + predicted_df = dlt.predict(df) + + dlt_new = DLTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.job', 'trend.filling'], + regressor_sign=['=', '=', '+'], + seasonality=52, + seed=8888, + ) + dlt_new.fit(df) + predicted_df_new = dlt_new.predict(df) + + assert np.allclose(predicted_df['prediction'].values, predicted_df_new['prediction'].values) diff --git a/tests/orbit/models/test_lgt.py b/tests/orbit/models/test_lgt.py index 2f21bc94..b47e9316 100644 --- a/tests/orbit/models/test_lgt.py +++ b/tests/orbit/models/test_lgt.py @@ -1,4 +1,5 @@ import pytest +import numpy as np from orbit.estimators.pyro_estimator import PyroEstimator, PyroEstimatorVI, PyroEstimatorMAP from orbit.estimators.stan_estimator import StanEstimator, StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP @@ -313,6 +314,33 @@ def test_lgt_predict_all_positive_reg(iclaims_training_data): assert any(predicted_df['regression'].values) +def test_lgt_predict_mixed_regular_positive(iclaims_training_data): + df = iclaims_training_data + + lgt = LGTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.filling', 'trend.job'], + regressor_sign=['=', '+', '='], + seasonality=52, + seed=8888, + ) + lgt.fit(df) + predicted_df = lgt.predict(df) + + lgt_new = LGTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.job', 'trend.filling'], + regressor_sign=['=', '=', '+'], + seasonality=52, + seed=8888, + ) + lgt_new.fit(df) + predicted_df_new = lgt_new.predict(df) + + assert np.allclose(predicted_df['prediction'].values, predicted_df_new['prediction'].values) + @pytest.mark.parametrize("prediction_percentiles", [None, [5, 10, 95]]) def test_prediction_percentiles(iclaims_training_data, prediction_percentiles):