From 2b4aba962d4aef978fa1677e412186677e343a17 Mon Sep 17 00:00:00 2001 From: Edwin Ng Date: Wed, 27 Jan 2021 19:55:07 -0800 Subject: [PATCH] Feat pickle fix (#342) * first pickle fix attempt * better attempt by using initializer creating an initializer class to tackle problem of #340 fixed #40 * ETS Initializer add ets initializer and change some wording from stan to generic * change wording from stan to generic [minor] minor wording --- orbit/constants/constants.py | 4 ++-- orbit/constants/dlt.py | 16 ++++++++++--- orbit/constants/ets.py | 14 +++++++---- orbit/constants/lgt.py | 16 ++++++++++--- orbit/diagnostics/plot.py | 4 ++-- orbit/estimators/base_estimator.py | 8 +++---- orbit/estimators/pyro_estimator.py | 8 +++---- orbit/initializer/__init__.py | 0 orbit/initializer/dlt.py | 33 ++++++++++++++++++++++++++ orbit/initializer/ets.py | 16 +++++++++++++ orbit/initializer/lgt.py | 33 ++++++++++++++++++++++++++ orbit/models/dlt.py | 37 +++++++----------------------- orbit/models/ets.py | 23 ++++++------------- orbit/models/lgt.py | 36 +++++++---------------------- 14 files changed, 153 insertions(+), 95 deletions(-) create mode 100644 orbit/initializer/__init__.py create mode 100644 orbit/initializer/dlt.py create mode 100644 orbit/initializer/ets.py create mode 100644 orbit/initializer/lgt.py diff --git a/orbit/constants/constants.py b/orbit/constants/constants.py index 6e76ec4b..6e858ef7 100644 --- a/orbit/constants/constants.py +++ b/orbit/constants/constants.py @@ -28,8 +28,8 @@ class EstimatorOptionsMapper(Enum): set of options) """ ENGINE_TO_SAMPLE = { - 'stan': ['map','vi','mcmc'], - 'pyro': ['map','vi'] + 'stan': ['map', 'vi', 'mcmc'], + 'pyro': ['map', 'vi'] } SAMPLE_TO_PREDICT = { 'map': ['map'], diff --git a/orbit/constants/dlt.py b/orbit/constants/dlt.py index 807a1258..2ebaa0cb 100644 --- a/orbit/constants/dlt.py +++ b/orbit/constants/dlt.py @@ -53,7 +53,7 @@ class GlobalTrendOption(Enum): class BaseSamplingParameters(Enum): """ - The stan output sampling parameters related with DLT base model. + base parameters in posteriors sampling """ # ---------- Common Local Trend ---------- # LOCAL_TREND_LEVELS = 'l' @@ -75,7 +75,7 @@ class GlobalTrendSamplingParameters(Enum): class SeasonalitySamplingParameters(Enum): """ - The stan output sampling parameters related with seasonality component. + seasonality component related parameters in posteriors sampling """ SEASONALITY_LEVELS = 's' SEASONALITY_SMOOTHING_FACTOR = 'sea_sm' @@ -83,11 +83,21 @@ class SeasonalitySamplingParameters(Enum): class RegressionSamplingParameters(Enum): """ - The stan output sampling parameters related with regression component. + regression component related parameters in posteriors sampling """ REGRESSION_COEFFICIENTS = 'beta' +class LatentSamplingParameters(Enum): + """ + latent variables to be sampled + """ + REGRESSION_POSITIVE_COEFFICIENTS = 'pr_beta' + REGRESSION_NEGATIVE_COEFFICIENTS = 'nr_beta' + REGRESSION_REGULAR_COEFFICIENTS = 'rr_beta' + INITIAL_SEASONALITY = 'init_sea' + + class RegressionPenalty(Enum): fixed_ridge = 0 lasso = 1 diff --git a/orbit/constants/ets.py b/orbit/constants/ets.py index 6a67a40d..08529c3b 100644 --- a/orbit/constants/ets.py +++ b/orbit/constants/ets.py @@ -3,9 +3,8 @@ class DataInputMapper(Enum): """ - mapping from object input to stan file + mapping from object input to sampler """ - # All of the following have default defined in DEFAULT_SLGT_FIT_ATTRIBUTES # ---------- Data Input ---------- # # observation related _NUM_OF_OBSERVATIONS = 'NUM_OF_OBS' @@ -21,7 +20,7 @@ class DataInputMapper(Enum): class BaseSamplingParameters(Enum): """ - The stan output sampling parameters related with LGT base model. + base parameters in posteriors sampling """ # ---------- Common Local Trend ---------- # LOCAL_TREND_LEVELS = 'l' @@ -32,7 +31,14 @@ class BaseSamplingParameters(Enum): class SeasonalitySamplingParameters(Enum): """ - The stan output sampling parameters related with seasonality component. + seasonality component related parameters in posteriors sampling """ SEASONALITY_LEVELS = 's' SEASONALITY_SMOOTHING_FACTOR = 'sea_sm' + + +class LatentSamplingParameters(Enum): + """ + latent variables to be sampled + """ + INITIAL_SEASONALITY = 'init_sea' diff --git a/orbit/constants/lgt.py b/orbit/constants/lgt.py index 99e84c44..0ea9d184 100644 --- a/orbit/constants/lgt.py +++ b/orbit/constants/lgt.py @@ -41,7 +41,7 @@ class DataInputMapper(Enum): class BaseSamplingParameters(Enum): """ - The stan output sampling parameters related with LGT base model. + base parameters in posteriors sampling """ # ---------- Common Local Trend ---------- # LOCAL_TREND_LEVELS = 'l' @@ -60,7 +60,7 @@ class BaseSamplingParameters(Enum): class SeasonalitySamplingParameters(Enum): """ - The stan output sampling parameters related with seasonality component. + seasonality component related parameters in posteriors sampling """ SEASONALITY_LEVELS = 's' SEASONALITY_SMOOTHING_FACTOR = 'sea_sm' @@ -68,11 +68,21 @@ class SeasonalitySamplingParameters(Enum): class RegressionSamplingParameters(Enum): """ - The stan output sampling parameters related with regression component. + regression component related parameters in posteriors sampling """ REGRESSION_COEFFICIENTS = 'beta' +class LatentSamplingParameters(Enum): + """ + latent variables to be sampled + """ + REGRESSION_POSITIVE_COEFFICIENTS = 'pr_beta' + REGRESSION_NEGATIVE_COEFFICIENTS = 'nr_beta' + REGRESSION_REGULAR_COEFFICIENTS = 'rr_beta' + INITIAL_SEASONALITY = 'init_sea' + + class RegressionPenalty(Enum): fixed_ridge = 0 lasso = 1 diff --git a/orbit/diagnostics/plot.py b/orbit/diagnostics/plot.py index c4b64523..45e17876 100644 --- a/orbit/diagnostics/plot.py +++ b/orbit/diagnostics/plot.py @@ -49,7 +49,7 @@ def plot_predicted_data(training_actual_df, predicted_df, date_col, actual_col, figsize pass through to `matplotlib.pyplot.figure()` path: str path to save the figure - fontzise: int + fontsize: int fontsize of the title Returns ------- @@ -146,7 +146,7 @@ def plot_predicted_components(predicted_df, date_col, prediction_percentiles=Non figsize pass through to `matplotlib.pyplot.figure()` path: str; optional path to save the figure - fontzise: int; optional + fontsize: int; optional fontsize of the title is_visible: boolean whether we want to show the plot. If called from unittest, is_visible might = False. diff --git a/orbit/estimators/base_estimator.py b/orbit/estimators/base_estimator.py index 702ace8a..812b6972 100644 --- a/orbit/estimators/base_estimator.py +++ b/orbit/estimators/base_estimator.py @@ -1,4 +1,4 @@ -from abc import ABCMeta, abstractmethod +from abc import abstractmethod import numpy as np @@ -10,7 +10,7 @@ class BaseEstimator(object): seed : int seed number for initial random values verbose : bool - If True, output all stan diagnostics messages + If True, output all diagnostics messages from estimators """ def __init__(self, seed=8888, verbose=False): @@ -27,11 +27,11 @@ def fit(self, model_name, model_param_names, data_input, init_values=None): Parameters ---------- model_name : str - name of stan model + name of model - used in mapping the right sampling file (stan/pyro/...) model_param_names : list list of strings of model parameters names to extract data_input : dict - key-value pairs of data input as required by definition in stan model + key-value pairs of data input as required by definition in samplers (stan/pyro/...) init_values : float or np.array initial sampler value. If None, 'random' is used diff --git a/orbit/estimators/pyro_estimator.py b/orbit/estimators/pyro_estimator.py index 75ee1f47..061d9bad 100644 --- a/orbit/estimators/pyro_estimator.py +++ b/orbit/estimators/pyro_estimator.py @@ -111,10 +111,10 @@ def fit(self, model_name, model_param_names, data_input, init_values=None): # make sure that model param names are a subset of stan extract keys invalid_model_param = set(model_param_names) - set(list(extract.keys())) if invalid_model_param: - raise EstimatorException("Stan model definition does not contain required parameters") + raise EstimatorException("Pyro model definition does not contain required parameters") # `stan.optimizing` automatically returns all defined parameters - # filter out unecessary keys + # filter out unnecessary keys extract = {param: extract[param] for param in model_param_names} return extract @@ -165,10 +165,10 @@ def fit(self, model_name, model_param_names, data_input, init_values=None): # make sure that model param names are a subset of stan extract keys invalid_model_param = set(model_param_names) - set(list(extract.keys())) if invalid_model_param: - raise EstimatorException("Stan model definition does not contain required parameters") + raise EstimatorException("Pyro model definition does not contain required parameters") # `stan.optimizing` automatically returns all defined parameters - # filter out unecessary keys + # filter out unnecessary keys extract = {param: extract[param] for param in model_param_names} return extract diff --git a/orbit/initializer/__init__.py b/orbit/initializer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orbit/initializer/dlt.py b/orbit/initializer/dlt.py new file mode 100644 index 00000000..f70852fe --- /dev/null +++ b/orbit/initializer/dlt.py @@ -0,0 +1,33 @@ +import numpy as np +from ..constants import dlt as constants + + +class DLTInitializer(object): + def __init__(self, s, n_pr, n_nr, n_rr): + self.s = s + self.n_pr = n_pr + self.n_nr = n_nr + self.n_rr = n_rr + + def __call__(self): + init_values = dict() + if self.s > 1: + init_sea = np.random.normal(loc=0, scale=0.05, size=self.s - 1) + # catch cases with extreme values + init_sea[init_sea > 1.0] = 1.0 + init_sea[init_sea < -1.0] = -1.0 + init_values[constants.LatentSamplingParameters.INITIAL_SEASONALITY.value] = init_sea + if self.n_pr > 0: + x = np.random.normal(loc=0, scale=0.1, size=self.n_pr) + x[x < 0] = -1 * x[x < 0] + init_values[constants.LatentSamplingParameters.REGRESSION_POSITIVE_COEFFICIENTS.value] = \ + x + if self.n_nr > 0: + x = np.random.normal(loc=-0, scale=0.1, size=self.n_nr) + x[x > 0] = -1 * x[x > 0] + init_values[constants.LatentSamplingParameters.REGRESSION_NEGATIVE_COEFFICIENTS.value] = \ + x + if self.n_rr > 0: + init_values[constants.LatentSamplingParameters.REGRESSION_REGULAR_COEFFICIENTS.value] = \ + np.random.normal(loc=-0, scale=0.1, size=self.n_rr) + return init_values diff --git a/orbit/initializer/ets.py b/orbit/initializer/ets.py new file mode 100644 index 00000000..62fe6420 --- /dev/null +++ b/orbit/initializer/ets.py @@ -0,0 +1,16 @@ +import numpy as np +from ..constants import ets as constants + + +class ETSInitializer(object): + def __init__(self, s): + self.s = s + + def __call__(self): + init_values = dict() + init_sea = np.random.normal(loc=0, scale=0.05, size=self.s - 1) + # catch cases with extreme values + init_sea[init_sea > 1.0] = 1.0 + init_sea[init_sea < -1.0] = -1.0 + init_values[constants.LatentSamplingParameters.INITIAL_SEASONALITY.value] = init_sea + return init_values diff --git a/orbit/initializer/lgt.py b/orbit/initializer/lgt.py new file mode 100644 index 00000000..3d868cc1 --- /dev/null +++ b/orbit/initializer/lgt.py @@ -0,0 +1,33 @@ +import numpy as np +from ..constants import lgt as constants + + +class LGTInitializer(object): + def __init__(self, s, n_pr, n_nr, n_rr): + self.s = s + self.n_pr = n_pr + self.n_nr = n_nr + self.n_rr = n_rr + + def __call__(self): + init_values = dict() + if self.s > 1: + init_sea = np.random.normal(loc=0, scale=0.05, size=self.s - 1) + # catch cases with extreme values + init_sea[init_sea > 1.0] = 1.0 + init_sea[init_sea < -1.0] = -1.0 + init_values[constants.LatentSamplingParameters.INITIAL_SEASONALITY.value] = init_sea + if self.n_pr > 0: + x = np.random.normal(loc=0, scale=0.1, size=self.n_pr) + x[x < 0] = -1 * x[x < 0] + init_values[constants.LatentSamplingParameters.REGRESSION_POSITIVE_COEFFICIENTS.value] = \ + x + if self.n_nr > 0: + x = np.random.normal(loc=-0, scale=0.1, size=self.n_nr) + x[x > 0] = -1 * x[x > 0] + init_values[constants.LatentSamplingParameters.REGRESSION_NEGATIVE_COEFFICIENTS.value] = \ + x + if self.n_rr > 0: + init_values[constants.LatentSamplingParameters.REGRESSION_REGULAR_COEFFICIENTS.value] = \ + np.random.normal(loc=-0, scale=0.1, size=self.n_rr) + return init_values diff --git a/orbit/models/dlt.py b/orbit/models/dlt.py index 8f8ea899..f4cd5cfe 100644 --- a/orbit/models/dlt.py +++ b/orbit/models/dlt.py @@ -16,6 +16,7 @@ from ..models.ets import BaseETS, ETSMAP, ETSFull, ETSAggregated from ..estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP from ..exceptions import IllegalArgument, ModelException, PredictionException +from ..initializer.dlt import DLTInitializer from ..utils.general import is_ordered_datetime @@ -145,39 +146,17 @@ def _set_init_values(self): See: https://pystan.readthedocs.io/en/latest/api.htm Overriding :func: `~orbit.models.BaseETS._set_init_values` """ - def init_values_function(s, n_pr, n_nr, n_rr): - init_values = dict() - if s > 1: - init_sea = np.random.normal(loc=0, scale=0.05, size=s - 1) - # catch cases with extreme values - init_sea[init_sea > 1.0] = 1.0 - init_sea[init_sea < -1.0] = -1.0 - init_values['init_sea'] = init_sea - if n_pr > 0: - x = np.random.normal(loc=0, scale=0.1, size=n_pr) - x[x < 0] = -1 * x[x < 0] - init_values['pr_beta'] = x - if n_nr > 0: - x = np.random.normal(loc=-0, scale=0.1, size=n_nr) - x[x > 0] = -1 * x[x > 0] - init_values['nr_beta'] = x - if n_rr > 0: - init_values['rr_beta'] = np.random.normal(loc=-0, scale=0.1, size=n_rr) - return init_values - - seasonality = self._seasonality - # init_values_partial = partial(init_values_callable, seasonality=seasonality) # partialfunc does not work when passed to PyStan because PyStan uses # inspect.getargspec(func) which seems to raise an exception with keyword-only args # caused by using partialfunc - # lambda as an alternative workaround - if seasonality > 1 or self._num_of_regressors > 0: - init_values_callable = lambda: init_values_function( - seasonality, - self._num_of_positive_regressors, - self._num_of_negative_regressors, - self._num_of_regular_regressors) + # lambda does not work in serialization in pickle + # callable object as an alternative workaround + if self._seasonality > 1 or self._num_of_regressors > 0: + init_values_callable = DLTInitializer( + self._seasonality, self._num_of_positive_regressors, self._num_of_negative_regressors, + self._num_of_regular_regressors + ) self._init_values = init_values_callable def _set_additional_trend_attributes(self): diff --git a/orbit/models/ets.py b/orbit/models/ets.py index 645ffb4a..caab3cf6 100644 --- a/orbit/models/ets.py +++ b/orbit/models/ets.py @@ -10,6 +10,7 @@ from ..estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP from ..estimators.pyro_estimator import PyroEstimatorVI, PyroEstimatorMAP from ..exceptions import IllegalArgument, ModelException, PredictionException +from ..initializer.ets import ETSInitializer from .base_model import BaseModel from ..utils.general import is_ordered_datetime @@ -113,26 +114,16 @@ def _set_init_values(self): """Set init as a callable (for Stan ONLY) See: https://pystan.readthedocs.io/en/latest/api.htm """ - def init_values_function(s): - init_values = dict() - if s > 1: - init_sea = np.random.normal(loc=0, scale=0.05, size=s-1) - # catch cases with extreme values - init_sea[init_sea > 1.0] = 1.0 - init_sea[init_sea < -1.0] = -1.0 - init_values['init_sea'] = init_sea - - return init_values - - seasonality = self._seasonality - # init_values_partial = partial(init_values_callable, seasonality=seasonality) # partialfunc does not work when passed to PyStan because PyStan uses # inspect.getargspec(func) which seems to raise an exception with keyword-only args # caused by using partialfunc - # lambda as an alternative workaround - if seasonality > 1: - init_values_callable = lambda: init_values_function(seasonality) # noqa + # lambda does not work in serialization in pickle + # callable object as an alternative workaround + if self._seasonality > 1: + init_values_callable = ETSInitializer( + self._seasonality + ) self._init_values = init_values_callable def _set_static_data_attributes(self): diff --git a/orbit/models/lgt.py b/orbit/models/lgt.py index f0728a8a..4b7a372a 100644 --- a/orbit/models/lgt.py +++ b/orbit/models/lgt.py @@ -17,6 +17,7 @@ from ..estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP from ..estimators.pyro_estimator import PyroEstimatorVI, PyroEstimatorMAP from ..exceptions import IllegalArgument, ModelException, PredictionException +from ..initializer.lgt import LGTInitializer from ..utils.general import is_ordered_datetime @@ -129,39 +130,16 @@ def _set_init_values(self): See: https://pystan.readthedocs.io/en/latest/api.htm Overriding :func: `~orbit.models.BaseETS._set_init_values` """ - def init_values_function(s, n_pr, n_nr, n_rr): - init_values = dict() - if s > 1: - init_sea = np.random.normal(loc=0, scale=0.05, size=s - 1) - # catch cases with extreme values - init_sea[init_sea > 1.0] = 1.0 - init_sea[init_sea < -1.0] = -1.0 - init_values['init_sea'] = init_sea - if n_pr > 0: - x = np.random.normal(loc=0, scale=0.1, size=n_pr) - x[x < 0] = -1 * x[x < 0] - init_values['pr_beta'] = x - if n_nr > 0: - x = np.random.normal(loc=-0, scale=0.1, size=n_nr) - x[x > 0] = -1 * x[x > 0] - init_values['nr_beta'] = x - if n_rr > 0: - init_values['rr_beta'] = np.random.normal(loc=-0, scale=0.1, size=n_rr) - return init_values - - seasonality = self._seasonality - # init_values_partial = partial(init_values_callable, seasonality=seasonality) # partialfunc does not work when passed to PyStan because PyStan uses # inspect.getargspec(func) which seems to raise an exception with keyword-only args # caused by using partialfunc # lambda as an alternative workaround - if seasonality > 1 or self._num_of_regressors > 0: - init_values_callable = lambda: init_values_function( - seasonality, - self._num_of_positive_regressors, - self._num_of_negative_regressors, - self._num_of_regular_regressors) + if self._seasonality > 1 or self._num_of_regressors > 0: + init_values_callable = LGTInitializer( + self._seasonality, self._num_of_positive_regressors, self._num_of_negative_regressors, + self._num_of_regular_regressors + ) self._init_values = init_values_callable def _set_additional_trend_attributes(self): @@ -672,3 +650,5 @@ def __init__(self, **kwargs): def get_regression_coefs(self): return super().get_regression_coefs(aggregate_method=PredictMethod.MAP.value) + +