diff --git a/.travis.yml b/.travis.yml index 84f0f486..bc726497 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,6 +27,17 @@ jobs: # environment variable in Travis see below. - fossa init - fossa analyze + - stage: deploy + script: skip + deploy: + provider: pypi + edge: true + username: "__token__" + password: $PYPI_API_TOKEN + on: + tags: true + skip_existing: true + distributions: "sdist bdist_wheel" after_success: - - cd orbit && fossa test \ No newline at end of file + - fossa test \ No newline at end of file diff --git a/README.rst b/README.rst index 890e37d7..ed1c5e64 100644 --- a/README.rst +++ b/README.rst @@ -1,15 +1,23 @@ -.. image:: docs/img/orbit-icon-small.png +.. image:: docs/img/orbit-banner.png ------------------------------------------- -**Disclaimer: Orbit requires PyStan as a system dependency. PyStan is -licensed under** `GPLv3 `__ **, -which is a free, copyleft license for software.** +|pypi| |travis| |downloads| -Orbit is a Python package for time series modeling and inference -using Bayesian sampling methods for model estimation. It provides a +Disclaimer +========== + +This project + +- is stable and being incubated for long-term support. It may contain new experimental code, for which APIs are subject to change. +- requires PyStan as a system dependency. PyStan is licensed under `GPLv3 `__, which is a free, copyleft license for software. + +Orbit: A Python package for Bayesian forecasting models +==================== + +Orbit is a Python package for Bayesian forecasting models developed under object-oriented design. It provides a familiar and intuitive initialize-fit-predict interface for working with -time series tasks, while utilizing probabilistic modeling under +time series tasks, while utilizing probabilistic modeling api under the hood. The initial release supports concrete implementation for the following @@ -134,3 +142,16 @@ Related projects - `Pyro `__ - `Stan `__ - `Rlgt `__ + + +.. |pypi| image:: https://badge.fury.io/py/orbit-ml.svg + :target: https://badge.fury.io/py/orbit-ml + :alt: pypi + +.. |travis| image:: https://travis-ci.com/uber/orbit.svg?branch=master + :target: https://travis-ci.com/uber/orbit + :alt: travis + +.. |downloads| image:: https://static.pepy.tech/personalized-badge/orbit-ml?period=month&units=international_system&left_color=blue&right_color=grey&left_text=Downloads + :target: https://pepy.tech/project/orbit-ml + :alt: downloads diff --git a/docs/_static/css/orbit.css b/docs/_static/css/orbit.css new file mode 100644 index 00000000..d6b524a3 --- /dev/null +++ b/docs/_static/css/orbit.css @@ -0,0 +1,29 @@ +@import url("theme.css"); + +.wy-side-nav-search { + background-color: #276EF1; +} + +.wy-side-nav-search a { + margin: 0 +} + +.wy-side-nav-search > div.version { + color: #ffffff; +} + +.wy-nav-top { + background: #949CE3; +} + +.wy-menu-vertical li.on a, .wy-menu-vertical li.current>a { + background: #c4cfd4; +} + +.wy-side-nav-search input[type=text] { + border-color: #313131; +} + +.wy-side-nav-search>a img.logo, .wy-side-nav-search .wy-dropdown>a img.logo { + max-width: 40%; +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index d5306483..fb264e55 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,6 +12,7 @@ import os import sys +import sphinx_rtd_theme sys.path.insert(0, os.path.abspath('..')) import matplotlib import orbit @@ -38,6 +39,7 @@ # 'sphinx.ext.autosummary', # 'sphinx.ext.doctest', 'sphinx.ext.mathjax', + 'sphinx.ext.githubpages', # 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', 'nbsphinx', @@ -62,6 +64,7 @@ # a list of builtin themes. # html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] html_context = { 'display_github': True, @@ -70,10 +73,26 @@ 'github_version': 'master/docs/', } +# logo +html_logo = 'img/orbit-logo-black.png' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + # 'logo_only': False, + 'navigation_depth': 3, +} + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] +html_style = 'css/orbit.css' + + +html_favicon = 'img/favicon/favicon.ico' napoleon_numpy_docstring = True napoleon_include_init_with_doc = True diff --git a/docs/img/favicon/favicon.ico b/docs/img/favicon/favicon.ico new file mode 100644 index 00000000..0ab8f362 Binary files /dev/null and b/docs/img/favicon/favicon.ico differ diff --git a/docs/img/orbit-banner.png b/docs/img/orbit-banner.png new file mode 100644 index 00000000..e1615e58 Binary files /dev/null and b/docs/img/orbit-banner.png differ diff --git a/docs/img/orbit-icon-raw.png b/docs/img/orbit-icon-raw.png new file mode 100644 index 00000000..bdb953aa Binary files /dev/null and b/docs/img/orbit-icon-raw.png differ diff --git a/docs/img/orbit-icon-small.png b/docs/img/orbit-icon-small.png deleted file mode 100644 index 5e968832..00000000 Binary files a/docs/img/orbit-icon-small.png and /dev/null differ diff --git a/docs/img/orbit-icon.png b/docs/img/orbit-icon.png deleted file mode 100644 index ac885c48..00000000 Binary files a/docs/img/orbit-icon.png and /dev/null differ diff --git a/docs/img/orbit-logo-black.png b/docs/img/orbit-logo-black.png new file mode 100644 index 00000000..ec446308 Binary files /dev/null and b/docs/img/orbit-logo-black.png differ diff --git a/docs/img/orbit-logo.png b/docs/img/orbit-logo.png new file mode 100644 index 00000000..38d2f5d7 Binary files /dev/null and b/docs/img/orbit-logo.png differ diff --git a/docs/tutorials/decompose_prediction.ipynb b/docs/tutorials/decompose_prediction.ipynb index ff8b0f8a..99514486 100644 --- a/docs/tutorials/decompose_prediction.ipynb +++ b/docs/tutorials/decompose_prediction.ipynb @@ -25,6 +25,7 @@ }, "outputs": [], "source": [ + "%matplotlib inline\n", "import pandas as pd\n", "import numpy as np\n", "from orbit.models.dlt import DLTMAP, DLTFull\n", @@ -197,7 +198,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.6.8" }, "toc": { "base_numbering": 1, diff --git a/docs/tutorials/dlt.ipynb b/docs/tutorials/dlt.ipynb index 5af8e142..8eeb92ba 100644 --- a/docs/tutorials/dlt.ipynb +++ b/docs/tutorials/dlt.ipynb @@ -72,6 +72,7 @@ }, "outputs": [], "source": [ + "%matplotlib inline\n", "import pandas as pd\n", "import numpy as np\n", "from orbit.models.dlt import DLTMAP, DLTFull, DLTAggregated\n", @@ -625,7 +626,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.6.8" }, "toc": { "base_numbering": 1, diff --git a/docs/tutorials/lgt.ipynb b/docs/tutorials/lgt.ipynb index 4a23322d..bb6e85dd 100644 --- a/docs/tutorials/lgt.ipynb +++ b/docs/tutorials/lgt.ipynb @@ -60,6 +60,7 @@ }, "outputs": [], "source": [ + "%matplotlib inline\n", "import pandas as pd\n", "import numpy as np\n", "from orbit.models.lgt import LGTMAP, LGTAggregated, LGTFull\n", @@ -162,8 +163,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 227 ms, sys: 14 ms, total: 241 ms\n", - "Wall time: 527 ms\n" + "CPU times: user 223 ms, sys: 9.22 ms, total: 232 ms\n", + "Wall time: 317 ms\n" ] } ], @@ -265,8 +266,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 74.3 ms, sys: 64.2 ms, total: 138 ms\n", - "Wall time: 8.14 s\n" + "CPU times: user 70.2 ms, sys: 67 ms, total: 137 ms\n", + "Wall time: 8.44 s\n" ] } ], @@ -320,7 +321,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.6.8" }, "toc": { "base_numbering": 1, diff --git a/docs/tutorials/pyro_basic.ipynb b/docs/tutorials/pyro_basic.ipynb index ce5c5b27..b2131d77 100644 --- a/docs/tutorials/pyro_basic.ipynb +++ b/docs/tutorials/pyro_basic.ipynb @@ -29,6 +29,7 @@ }, "outputs": [], "source": [ + "%matplotlib inline\n", "import pandas as pd\n", "import numpy as np\n", "from orbit.models.lgt import LGTMAP, LGTAggregated, LGTFull\n", @@ -303,7 +304,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.6.8" }, "toc": { "base_numbering": 1, diff --git a/docs/tutorials/quick_start.ipynb b/docs/tutorials/quick_start.ipynb index 958239f1..3e3aa3b9 100644 --- a/docs/tutorials/quick_start.ipynb +++ b/docs/tutorials/quick_start.ipynb @@ -18,6 +18,7 @@ }, "outputs": [], "source": [ + "%matplotlib inline\n", "import pandas as pd\n", "import numpy as np\n", "import warnings\n", @@ -210,7 +211,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.6.8" }, "toc": { "base_numbering": 1, diff --git a/examples/DLT_Example.ipynb b/examples/DLT_Example.ipynb index 365ef49a..76ed0a50 100644 --- a/examples/DLT_Example.ipynb +++ b/examples/DLT_Example.ipynb @@ -210,7 +210,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=df, predicted_df=predicted_df, \n", + "_ = plot_predicted_data(training_actual_df=df, predicted_df=predicted_df, \n", " date_col=dlt.date_col, actual_col=dlt.response_col)" ] }, @@ -298,7 +298,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=df, predicted_df=predicted_df_dlt_log, \n", + "_ = plot_predicted_data(training_actual_df=df, predicted_df=predicted_df_dlt_log, \n", " date_col=dlt_log.date_col, actual_col=dlt_log.response_col)" ] }, @@ -386,7 +386,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=df, predicted_df=predicted_df_dlt_logit, \n", + "_ = plot_predicted_data(training_actual_df=df, predicted_df=predicted_df_dlt_logit, \n", " date_col=dlt_log.date_col, actual_col=dlt_log.response_col)" ] }, @@ -467,7 +467,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=df, predicted_df=predicted_df_dlt_logit, \n", + "_ = plot_predicted_data(training_actual_df=df, predicted_df=predicted_df_dlt_logit, \n", " date_col=dlt_log.date_col, actual_col=dlt_log.response_col)" ] } diff --git a/examples/Daily_Forecast_Example.ipynb b/examples/Daily_Forecast_Example.ipynb index 6759e209..2a6e485d 100644 --- a/examples/Daily_Forecast_Example.ipynb +++ b/examples/Daily_Forecast_Example.ipynb @@ -243,7 +243,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=df[-90:], predicted_df=predicted_df[-90:], \n", + "_ = plot_predicted_data(training_actual_df=df[-90:], predicted_df=predicted_df[-90:], \n", " test_actual_df=test_df, date_col=dlt.date_col,\n", " actual_col='sales')" ] diff --git a/examples/LGT_Example.ipynb b/examples/LGT_Example.ipynb index 22b73bd8..75bfa594 100644 --- a/examples/LGT_Example.ipynb +++ b/examples/LGT_Example.ipynb @@ -419,7 +419,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", + "_ = plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", " date_col=date_col, actual_col=response_col, \n", " test_actual_df=test_df)" ] @@ -610,7 +610,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", + "_ = plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", " date_col=lgt.date_col, actual_col=lgt.response_col, \n", " test_actual_df=test_df)" ] @@ -884,7 +884,7 @@ } ], "source": [ - "plot_predicted_components(predicted_df=predicted_df, date_col=date_col)" + "_ = plot_predicted_components(predicted_df=predicted_df, date_col=date_col)" ] } ], diff --git a/examples/LGT_Pyro_Example.ipynb b/examples/LGT_Pyro_Example.ipynb index f6ac8327..5708e18d 100644 --- a/examples/LGT_Pyro_Example.ipynb +++ b/examples/LGT_Pyro_Example.ipynb @@ -196,7 +196,7 @@ ], "source": [ "predicted_df = lgt_map.predict(df=test_df)\n", - "plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", + "_ = plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", " date_col=lgt_map.date_col, actual_col=lgt_map.response_col, \n", " test_actual_df=test_df)" ] @@ -314,7 +314,7 @@ } ], "source": [ - "plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", + "_ = plot_predicted_data(training_actual_df=train_df, predicted_df=predicted_df, \n", " date_col=lgt_vi.date_col, actual_col=lgt_vi.response_col, \n", " test_actual_df=test_df)" ] @@ -322,9 +322,9 @@ ], "metadata": { "kernelspec": { - "display_name": "orbit", + "display_name": "Python 3", "language": "python", - "name": "orbit" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -336,7 +336,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.7" }, "toc": { "base_numbering": 1, diff --git a/orbit/__init__.py b/orbit/__init__.py index e6100845..112006d1 100644 --- a/orbit/__init__.py +++ b/orbit/__init__.py @@ -1,3 +1,3 @@ name = 'orbit' -__version__ = '1.0.5' +__version__ = '1.0.10' diff --git a/orbit/diagnostics/plot.py b/orbit/diagnostics/plot.py index d5a980fb..0787999b 100644 --- a/orbit/diagnostics/plot.py +++ b/orbit/diagnostics/plot.py @@ -113,8 +113,11 @@ def plot_predicted_data(training_actual_df, predicted_df, date_col, actual_col, if is_visible: plt.show() + return ax + +def plot_predicted_components(predicted_df, date_col, prediction_percentiles=None, plot_components=None, + title="", figsize=None, path=None): -def plot_predicted_components(predicted_df, date_col, prediction_percentiles=None, title="", figsize=None, path=None): """ Plot predicted componenets with the data frame of decomposed prediction where components has been pre-defined as `trend`, `seasonality` and `regression`. Parameters @@ -127,6 +130,9 @@ def plot_predicted_components(predicted_df, date_col, prediction_percentiles=Non prediction_percentiles: list a list should consist exact two elements which will be used to plot as lower and upper bound of confidence interval + plot_components: list + a list of strings to show the label of components to be plotted; by default, it uses values in + `orbit.constants.constants.PredictedComponents`. title: str title of the plot figsize: tuple @@ -140,9 +146,11 @@ def plot_predicted_components(predicted_df, date_col, prediction_percentiles=Non _predicted_df=predicted_df.copy() _predicted_df[date_col] = pd.to_datetime(_predicted_df[date_col]) - plot_components = [PredictedComponents.TREND.value, - PredictedComponents.SEASONALITY.value, - PredictedComponents.REGRESSION.value] + if plot_components is None: + plot_components = [PredictedComponents.TREND.value, + PredictedComponents.SEASONALITY.value, + PredictedComponents.REGRESSION.value] + plot_components = [p for p in plot_components if p in _predicted_df.columns.tolist()] n_panels = len(plot_components) if not figsize: @@ -174,6 +182,8 @@ def plot_predicted_components(predicted_df, date_col, prediction_percentiles=Non if path: plt.savefig(path) + return axes + def metric_horizon_barplot(df, model_col='model', pred_horizon_col='pred_horizon', metric_col='smape', bar_width=0.1, path=None): @@ -211,9 +221,10 @@ def metric_horizon_barplot(df, model_col='model', pred_horizon_col='pred_horizon plt.savefig(path) + def plot_posterior_params(mod, kind='density', n_bins=20, ci_level=.95, - pair_type='scatter', figsize=None, path=None, - incl_trend_params=False, incl_smooth_params=False): + pair_type='scatter', figsize=None, path=None, + incl_trend_params=False, incl_smooth_params=False): """ Data Viz for posterior samples Params @@ -277,7 +288,7 @@ def _density_plot(posterior_samples, n_bins=20, ci_level=.95, figsize=None): mean = np.mean(samples) median = np.median(samples) cred_min, cred_max = np.percentile(samples, 100 * (1 - ci_level)/2), \ - np.percentile(samples, 100 * (1 + ci_level)/2) + np.percentile(samples, 100 * (1 + ci_level)/2) sns.distplot(samples, bins=n_bins, kde_kws={'shade':True}, ax=axes[i], norm_hist=False) # sns.kdeplot(samples, shade=True, ax=axes[i]) @@ -295,7 +306,7 @@ def _density_plot(posterior_samples, n_bins=20, ci_level=.95, figsize=None): plt.suptitle('Histogram and Density of Posterior Samples') plt.tight_layout(rect=[0, 0.03, 1, 0.95]) - return fig + return axes def _trace_plot(posterior_samples, ci_level=.95, figsize=None): @@ -325,7 +336,7 @@ def _trace_plot(posterior_samples, ci_level=.95, figsize=None): plt.xlabel('draw') plt.tight_layout(rect=[0, 0.03, 1, 0.95]) - return fig + return axes def _pair_plot(posterior_samples, pair_type='scatter', n_bins=20): samples_df = pd.DataFrame({key: posterior_samples[key].flatten() for key in params_}) @@ -337,13 +348,13 @@ def _pair_plot(posterior_samples, pair_type='scatter', n_bins=20): return fig if kind == 'density': - fig = _density_plot(posterior_samples, n_bins=n_bins, ci_level=ci_level, figsize=figsize) + axes = _density_plot(posterior_samples, n_bins=n_bins, ci_level=ci_level, figsize=figsize) elif kind == 'trace': - fig = _trace_plot(posterior_samples, ci_level=ci_level, figsize=figsize) + axes = _trace_plot(posterior_samples, ci_level=ci_level, figsize=figsize) elif kind == 'pair': - fig = _pair_plot(posterior_samples, pair_type=pair_type, n_bins=n_bins) + axes = _pair_plot(posterior_samples, pair_type=pair_type, n_bins=n_bins) if path: plt.savefig(path) - return fig + return axes diff --git a/orbit/models/dlt.py b/orbit/models/dlt.py index 89e630ff..0213f199 100644 --- a/orbit/models/dlt.py +++ b/orbit/models/dlt.py @@ -193,7 +193,7 @@ def _predict(self, posterior_estimates, df=None, include_error=False, decompose= # calculate regression component if self.regressor_col is not None and len(self.regressor_col) > 0: regressor_beta = regressor_beta.t() - regressor_matrix = df[self.regressor_col].values + regressor_matrix = df[self._regressor_col].values regressor_torch = torch.from_numpy(regressor_matrix).double() regressor_component = torch.matmul(regressor_torch, regressor_beta) regressor_component = regressor_component.t() diff --git a/orbit/models/lgt.py b/orbit/models/lgt.py index a25d8939..05677497 100644 --- a/orbit/models/lgt.py +++ b/orbit/models/lgt.py @@ -121,6 +121,8 @@ def __init__(self, response_col='y', date_col='ds', regressor_col=None, self._regular_regressor_col = list() self._regular_regressor_beta_prior = list() self._regular_regressor_sigma_prior = list() + # internal regressor_col to maintain column ordering on prediction + self._regressor_col = list() # depends on seasonality length self._init_values = None @@ -224,6 +226,8 @@ def _set_static_regression_attributes(self): self._regular_regressor_beta_prior.append(self._regressor_beta_prior[index]) self._regular_regressor_sigma_prior.append(self._regressor_sigma_prior[index]) + self._regressor_col = self._positive_regressor_col + self._regular_regressor_col + def _set_with_mcmc(self): estimator_type = self.estimator_type # set `_with_mcmc` attribute based on estimator type @@ -521,7 +525,7 @@ def _predict(self, posterior_estimates, df, include_error=False, decompose=False # calculate regression component if self.regressor_col is not None and len(self.regressor_col) > 0: regressor_beta = regressor_beta.t() - regressor_matrix = df[self.regressor_col].values + regressor_matrix = df[self._regressor_col].values regressor_torch = torch.from_numpy(regressor_matrix).double() regressor_component = torch.matmul(regressor_torch, regressor_beta) regressor_component = regressor_component.t() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..3886c687 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +version = attr: orbit.__version__ \ No newline at end of file diff --git a/setup.py b/setup.py index 441fce9d..4ac95f8e 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ # dist.Distribution().fetch_build_eggs(['cython']) -VERSION = '1.0.5' +# VERSION = '1.0.6' DESCRIPTION = "Orbit is a package for bayesian time series modeling and inference." AUTHOR = '''Edwin Ng , Steve Yang , Huigang Chen , Zhishi Wang ''' @@ -60,7 +60,7 @@ def run_tests(self): name='orbit-ml', packages=find_packages(), url='https://uber.github.io/orbit/', - version=VERSION, + # version=VERSION, # being maintained by source module zip_safe=False, classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/tests/orbit/models/test_dlt.py b/tests/orbit/models/test_dlt.py index 53e21010..8301b5d7 100644 --- a/tests/orbit/models/test_dlt.py +++ b/tests/orbit/models/test_dlt.py @@ -1,4 +1,6 @@ import pytest +import numpy as np + from orbit.models.dlt import BaseDLT, DLTFull, DLTAggregated, DLTMAP from orbit.estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP @@ -235,6 +237,7 @@ def test_dlt_map_global_trend(synthetic_data, global_trend_option): assert predict_df.shape == expected_shape assert predict_df.columns.tolist() == expected_columns + def test_dlt_predict_all_positive_reg(iclaims_training_data): df = iclaims_training_data @@ -251,3 +254,31 @@ def test_dlt_predict_all_positive_reg(iclaims_training_data): predicted_df = dlt.predict(df, decompose=True) assert any(predicted_df['regression'].values) + + +def test_dlt_predict_mixed_regular_positive(iclaims_training_data): + df = iclaims_training_data + + dlt = DLTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.filling', 'trend.job'], + regressor_sign=['=', '+', '='], + seasonality=52, + seed=8888, + ) + dlt.fit(df) + predicted_df = dlt.predict(df) + + dlt_new = DLTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.job', 'trend.filling'], + regressor_sign=['=', '=', '+'], + seasonality=52, + seed=8888, + ) + dlt_new.fit(df) + predicted_df_new = dlt_new.predict(df) + + assert np.allclose(predicted_df['prediction'].values, predicted_df_new['prediction'].values) diff --git a/tests/orbit/models/test_lgt.py b/tests/orbit/models/test_lgt.py index 2f21bc94..4fda3e65 100644 --- a/tests/orbit/models/test_lgt.py +++ b/tests/orbit/models/test_lgt.py @@ -1,10 +1,12 @@ import pytest +import numpy as np from orbit.estimators.pyro_estimator import PyroEstimator, PyroEstimatorVI, PyroEstimatorMAP from orbit.estimators.stan_estimator import StanEstimator, StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP from orbit.models.lgt import BaseLGT, LGTFull, LGTAggregated, LGTMAP from orbit.constants.constants import PredictedComponents + def test_base_lgt_init(): lgt = BaseLGT() @@ -314,6 +316,34 @@ def test_lgt_predict_all_positive_reg(iclaims_training_data): assert any(predicted_df['regression'].values) +def test_lgt_predict_mixed_regular_positive(iclaims_training_data): + df = iclaims_training_data + + lgt = LGTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.filling', 'trend.job'], + regressor_sign=['=', '+', '='], + seasonality=52, + seed=8888, + ) + lgt.fit(df) + predicted_df = lgt.predict(df) + + lgt_new = LGTMAP( + response_col='claims', + date_col='week', + regressor_col=['trend.unemploy', 'trend.job', 'trend.filling'], + regressor_sign=['=', '=', '+'], + seasonality=52, + seed=8888, + ) + lgt_new.fit(df) + predicted_df_new = lgt_new.predict(df) + + assert np.allclose(predicted_df['prediction'].values, predicted_df_new['prediction'].values) + + @pytest.mark.parametrize("prediction_percentiles", [None, [5, 10, 95]]) def test_prediction_percentiles(iclaims_training_data, prediction_percentiles): df = iclaims_training_data