Skip to content

Commit

Permalink
Feat pickle fix (#342)
Browse files Browse the repository at this point in the history
* first pickle fix attempt

* better attempt by using initializer

Create an initializer class to address the pickling problem in #340; fixes #40.

* ETS Initializer

add ets initializer and change some wording from stan to generic

* change wording from stan to generic [minor]

minor wording
  • Loading branch information
Edwin Ng authored Jan 28, 2021
1 parent 307b5b4 commit 2b4aba9
Show file tree
Hide file tree
Showing 14 changed files with 153 additions and 95 deletions.
4 changes: 2 additions & 2 deletions orbit/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class EstimatorOptionsMapper(Enum):
set of options)
"""
ENGINE_TO_SAMPLE = {
'stan': ['map','vi','mcmc'],
'pyro': ['map','vi']
'stan': ['map', 'vi', 'mcmc'],
'pyro': ['map', 'vi']
}
SAMPLE_TO_PREDICT = {
'map': ['map'],
Expand Down
16 changes: 13 additions & 3 deletions orbit/constants/dlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class GlobalTrendOption(Enum):

class BaseSamplingParameters(Enum):
"""
The stan output sampling parameters related with DLT base model.
base parameters in posteriors sampling
"""
# ---------- Common Local Trend ---------- #
LOCAL_TREND_LEVELS = 'l'
Expand All @@ -75,19 +75,29 @@ class GlobalTrendSamplingParameters(Enum):

class SeasonalitySamplingParameters(Enum):
"""
The stan output sampling parameters related with seasonality component.
seasonality component related parameters in posteriors sampling
"""
SEASONALITY_LEVELS = 's'
SEASONALITY_SMOOTHING_FACTOR = 'sea_sm'


class RegressionSamplingParameters(Enum):
"""
The stan output sampling parameters related with regression component.
regression component related parameters in posteriors sampling
"""
REGRESSION_COEFFICIENTS = 'beta'


class LatentSamplingParameters(Enum):
    """
    latent variables to be sampled

    Each member's value is the parameter name used as a key for that latent
    variable (e.g. in a dict of initial values handed to the sampler).
    """
    # coefficients of the positive regressors
    REGRESSION_POSITIVE_COEFFICIENTS = 'pr_beta'
    # coefficients of the negative regressors
    REGRESSION_NEGATIVE_COEFFICIENTS = 'nr_beta'
    # coefficients of the regular regressors
    REGRESSION_REGULAR_COEFFICIENTS = 'rr_beta'
    # initial seasonality values
    INITIAL_SEASONALITY = 'init_sea'


class RegressionPenalty(Enum):
fixed_ridge = 0
lasso = 1
Expand Down
14 changes: 10 additions & 4 deletions orbit/constants/ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

class DataInputMapper(Enum):
"""
mapping from object input to stan file
mapping from object input to sampler
"""
# All of the following have default defined in DEFAULT_SLGT_FIT_ATTRIBUTES
# ---------- Data Input ---------- #
# observation related
_NUM_OF_OBSERVATIONS = 'NUM_OF_OBS'
Expand All @@ -21,7 +20,7 @@ class DataInputMapper(Enum):

class BaseSamplingParameters(Enum):
"""
The stan output sampling parameters related with LGT base model.
base parameters in posteriors sampling
"""
# ---------- Common Local Trend ---------- #
LOCAL_TREND_LEVELS = 'l'
Expand All @@ -32,7 +31,14 @@ class BaseSamplingParameters(Enum):

class SeasonalitySamplingParameters(Enum):
"""
The stan output sampling parameters related with seasonality component.
seasonality component related parameters in posteriors sampling
"""
SEASONALITY_LEVELS = 's'
SEASONALITY_SMOOTHING_FACTOR = 'sea_sm'


class LatentSamplingParameters(Enum):
    """
    latent variables to be sampled

    Each member's value is the parameter name used as a key for that latent
    variable (e.g. in a dict of initial values handed to the sampler).
    """
    # initial seasonality values
    INITIAL_SEASONALITY = 'init_sea'
16 changes: 13 additions & 3 deletions orbit/constants/lgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DataInputMapper(Enum):

class BaseSamplingParameters(Enum):
"""
The stan output sampling parameters related with LGT base model.
base parameters in posteriors sampling
"""
# ---------- Common Local Trend ---------- #
LOCAL_TREND_LEVELS = 'l'
Expand All @@ -60,19 +60,29 @@ class BaseSamplingParameters(Enum):

class SeasonalitySamplingParameters(Enum):
"""
The stan output sampling parameters related with seasonality component.
seasonality component related parameters in posteriors sampling
"""
SEASONALITY_LEVELS = 's'
SEASONALITY_SMOOTHING_FACTOR = 'sea_sm'


class RegressionSamplingParameters(Enum):
"""
The stan output sampling parameters related with regression component.
regression component related parameters in posteriors sampling
"""
REGRESSION_COEFFICIENTS = 'beta'


class LatentSamplingParameters(Enum):
    """
    latent variables to be sampled

    Each member's value is the parameter name used as a key for that latent
    variable (e.g. in a dict of initial values handed to the sampler).
    """
    # coefficients of the positive regressors
    REGRESSION_POSITIVE_COEFFICIENTS = 'pr_beta'
    # coefficients of the negative regressors
    REGRESSION_NEGATIVE_COEFFICIENTS = 'nr_beta'
    # coefficients of the regular regressors
    REGRESSION_REGULAR_COEFFICIENTS = 'rr_beta'
    # initial seasonality values
    INITIAL_SEASONALITY = 'init_sea'


class RegressionPenalty(Enum):
fixed_ridge = 0
lasso = 1
Expand Down
4 changes: 2 additions & 2 deletions orbit/diagnostics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def plot_predicted_data(training_actual_df, predicted_df, date_col, actual_col,
figsize pass through to `matplotlib.pyplot.figure()`
path: str
path to save the figure
fontzise: int
fontsize: int
fontsize of the title
Returns
-------
Expand Down Expand Up @@ -146,7 +146,7 @@ def plot_predicted_components(predicted_df, date_col, prediction_percentiles=Non
figsize pass through to `matplotlib.pyplot.figure()`
path: str; optional
path to save the figure
fontzise: int; optional
fontsize: int; optional
fontsize of the title
is_visible: boolean
whether we want to show the plot. If called from unittest, is_visible might = False.
Expand Down
8 changes: 4 additions & 4 deletions orbit/estimators/base_estimator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABCMeta, abstractmethod
from abc import abstractmethod
import numpy as np


Expand All @@ -10,7 +10,7 @@ class BaseEstimator(object):
seed : int
seed number for initial random values
verbose : bool
If True, output all stan diagnostics messages
If True, output all diagnostics messages from estimators
"""
def __init__(self, seed=8888, verbose=False):
Expand All @@ -27,11 +27,11 @@ def fit(self, model_name, model_param_names, data_input, init_values=None):
Parameters
----------
model_name : str
name of stan model
name of model - used in mapping the right sampling file (stan/pyro/...)
model_param_names : list
list of strings of model parameters names to extract
data_input : dict
key-value pairs of data input as required by definition in stan model
key-value pairs of data input as required by definition in samplers (stan/pyro/...)
init_values : float or np.array
initial sampler value. If None, 'random' is used
Expand Down
8 changes: 4 additions & 4 deletions orbit/estimators/pyro_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def fit(self, model_name, model_param_names, data_input, init_values=None):
# make sure that model param names are a subset of stan extract keys
invalid_model_param = set(model_param_names) - set(list(extract.keys()))
if invalid_model_param:
raise EstimatorException("Stan model definition does not contain required parameters")
raise EstimatorException("Pyro model definition does not contain required parameters")

# `stan.optimizing` automatically returns all defined parameters
# filter out unecessary keys
# filter out unnecessary keys
extract = {param: extract[param] for param in model_param_names}

return extract
Expand Down Expand Up @@ -165,10 +165,10 @@ def fit(self, model_name, model_param_names, data_input, init_values=None):
# make sure that model param names are a subset of stan extract keys
invalid_model_param = set(model_param_names) - set(list(extract.keys()))
if invalid_model_param:
raise EstimatorException("Stan model definition does not contain required parameters")
raise EstimatorException("Pyro model definition does not contain required parameters")

# `stan.optimizing` automatically returns all defined parameters
# filter out unecessary keys
# filter out unnecessary keys
extract = {param: extract[param] for param in model_param_names}

return extract
Empty file added orbit/initializer/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions orbit/initializer/dlt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from ..constants import dlt as constants


class DLTInitializer(object):
    """Callable object that draws random initial values for DLT latent parameters.

    Used in place of a lambda/closure so that the object holding it can be
    serialized with pickle; an instance stores nothing but four counts.

    Parameters
    ----------
    s : int
        seasonality; seasonal initial values are drawn only when ``s > 1``
    n_pr : int
        number of positive regressors
    n_nr : int
        number of negative regressors
    n_rr : int
        number of regular regressors
    """

    def __init__(self, s, n_pr, n_nr, n_rr):
        self.s, self.n_pr, self.n_nr, self.n_rr = s, n_pr, n_nr, n_rr

    def __call__(self):
        """Draw one set of initial values from the global numpy random state.

        Returns
        -------
        dict
            latent parameter name -> numpy array of initial values; a component
            is left out when its count is not large enough to need one
        """
        drawn = {}

        if self.s > 1:
            seasonal = np.random.normal(loc=0, scale=0.05, size=self.s - 1)
            # bound extreme draws to [-1, 1]
            drawn[constants.LatentSamplingParameters.INITIAL_SEASONALITY.value] = \
                np.clip(seasonal, -1.0, 1.0)

        if self.n_pr > 0:
            positive = np.random.normal(loc=0, scale=0.1, size=self.n_pr)
            # flip the sign of negative draws so every value is non-negative
            drawn[constants.LatentSamplingParameters.REGRESSION_POSITIVE_COEFFICIENTS.value] = \
                np.where(positive < 0, -positive, positive)

        if self.n_nr > 0:
            negative = np.random.normal(loc=0, scale=0.1, size=self.n_nr)
            # flip the sign of positive draws so every value is non-positive
            drawn[constants.LatentSamplingParameters.REGRESSION_NEGATIVE_COEFFICIENTS.value] = \
                np.where(negative > 0, -negative, negative)

        if self.n_rr > 0:
            # regular coefficients are unconstrained in sign
            drawn[constants.LatentSamplingParameters.REGRESSION_REGULAR_COEFFICIENTS.value] = \
                np.random.normal(loc=0, scale=0.1, size=self.n_rr)

        return drawn
16 changes: 16 additions & 0 deletions orbit/initializer/ets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
from ..constants import ets as constants


class ETSInitializer(object):
    """Callable object that draws random initial seasonality values for ETS.

    Used in place of a lambda/closure so that the object holding it can be
    serialized with pickle; an instance stores nothing but the seasonality.

    Parameters
    ----------
    s : int
        seasonality; initial values are drawn only when ``s > 1``
    """

    def __init__(self, s):
        self.s = s

    def __call__(self):
        """Draw one set of initial values from the global numpy random state.

        Returns
        -------
        dict
            ``{'init_sea': array of length s - 1}`` when ``s > 1``, otherwise an
            empty dict
        """
        init_values = dict()
        # Guard kept consistent with the DLT/LGT initializers: without it
        # ``s == 0`` asks numpy for a negative size (ValueError) and ``s == 1``
        # emits an empty ``init_sea`` array.
        if self.s > 1:
            init_sea = np.random.normal(loc=0, scale=0.05, size=self.s - 1)
            # bound extreme draws to [-1, 1]
            init_values[constants.LatentSamplingParameters.INITIAL_SEASONALITY.value] = \
                np.clip(init_sea, -1.0, 1.0)
        return init_values
33 changes: 33 additions & 0 deletions orbit/initializer/lgt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from ..constants import lgt as constants


class LGTInitializer(object):
    """Callable object that draws random initial values for LGT latent parameters.

    Used in place of a lambda/closure so that the object holding it can be
    serialized with pickle; an instance stores nothing but four counts.

    Parameters
    ----------
    s : int
        seasonality; seasonal initial values are drawn only when ``s > 1``
    n_pr : int
        number of positive regressors
    n_nr : int
        number of negative regressors
    n_rr : int
        number of regular regressors
    """

    def __init__(self, s, n_pr, n_nr, n_rr):
        self.s, self.n_pr, self.n_nr, self.n_rr = s, n_pr, n_nr, n_rr

    def __call__(self):
        """Draw one set of initial values from the global numpy random state.

        Returns
        -------
        dict
            latent parameter name -> numpy array of initial values; a component
            is left out when its count is not large enough to need one
        """
        drawn = {}

        if self.s > 1:
            seasonal = np.random.normal(loc=0, scale=0.05, size=self.s - 1)
            # bound extreme draws to [-1, 1]
            drawn[constants.LatentSamplingParameters.INITIAL_SEASONALITY.value] = \
                np.clip(seasonal, -1.0, 1.0)

        if self.n_pr > 0:
            positive = np.random.normal(loc=0, scale=0.1, size=self.n_pr)
            # flip the sign of negative draws so every value is non-negative
            drawn[constants.LatentSamplingParameters.REGRESSION_POSITIVE_COEFFICIENTS.value] = \
                np.where(positive < 0, -positive, positive)

        if self.n_nr > 0:
            negative = np.random.normal(loc=0, scale=0.1, size=self.n_nr)
            # flip the sign of positive draws so every value is non-positive
            drawn[constants.LatentSamplingParameters.REGRESSION_NEGATIVE_COEFFICIENTS.value] = \
                np.where(negative > 0, -negative, negative)

        if self.n_rr > 0:
            # regular coefficients are unconstrained in sign
            drawn[constants.LatentSamplingParameters.REGRESSION_REGULAR_COEFFICIENTS.value] = \
                np.random.normal(loc=0, scale=0.1, size=self.n_rr)

        return drawn
37 changes: 8 additions & 29 deletions orbit/models/dlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..models.ets import BaseETS, ETSMAP, ETSFull, ETSAggregated
from ..estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP
from ..exceptions import IllegalArgument, ModelException, PredictionException
from ..initializer.dlt import DLTInitializer
from ..utils.general import is_ordered_datetime


Expand Down Expand Up @@ -145,39 +146,17 @@ def _set_init_values(self):
See: https://pystan.readthedocs.io/en/latest/api.htm
Overriding :func: `~orbit.models.BaseETS._set_init_values`
"""
def init_values_function(s, n_pr, n_nr, n_rr):
init_values = dict()
if s > 1:
init_sea = np.random.normal(loc=0, scale=0.05, size=s - 1)
# catch cases with extreme values
init_sea[init_sea > 1.0] = 1.0
init_sea[init_sea < -1.0] = -1.0
init_values['init_sea'] = init_sea
if n_pr > 0:
x = np.random.normal(loc=0, scale=0.1, size=n_pr)
x[x < 0] = -1 * x[x < 0]
init_values['pr_beta'] = x
if n_nr > 0:
x = np.random.normal(loc=-0, scale=0.1, size=n_nr)
x[x > 0] = -1 * x[x > 0]
init_values['nr_beta'] = x
if n_rr > 0:
init_values['rr_beta'] = np.random.normal(loc=-0, scale=0.1, size=n_rr)
return init_values

seasonality = self._seasonality

# init_values_partial = partial(init_values_callable, seasonality=seasonality)
# partialfunc does not work when passed to PyStan because PyStan uses
# inspect.getargspec(func) which seems to raise an exception with keyword-only args
# caused by using partialfunc
# lambda as an alternative workaround
if seasonality > 1 or self._num_of_regressors > 0:
init_values_callable = lambda: init_values_function(
seasonality,
self._num_of_positive_regressors,
self._num_of_negative_regressors,
self._num_of_regular_regressors)
# lambda does not work in serialization in pickle
# callable object as an alternative workaround
if self._seasonality > 1 or self._num_of_regressors > 0:
init_values_callable = DLTInitializer(
self._seasonality, self._num_of_positive_regressors, self._num_of_negative_regressors,
self._num_of_regular_regressors
)
self._init_values = init_values_callable

def _set_additional_trend_attributes(self):
Expand Down
23 changes: 7 additions & 16 deletions orbit/models/ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorVI, StanEstimatorMAP
from ..estimators.pyro_estimator import PyroEstimatorVI, PyroEstimatorMAP
from ..exceptions import IllegalArgument, ModelException, PredictionException
from ..initializer.ets import ETSInitializer
from .base_model import BaseModel
from ..utils.general import is_ordered_datetime

Expand Down Expand Up @@ -113,26 +114,16 @@ def _set_init_values(self):
"""Set init as a callable (for Stan ONLY)
See: https://pystan.readthedocs.io/en/latest/api.htm
"""
def init_values_function(s):
init_values = dict()
if s > 1:
init_sea = np.random.normal(loc=0, scale=0.05, size=s-1)
# catch cases with extreme values
init_sea[init_sea > 1.0] = 1.0
init_sea[init_sea < -1.0] = -1.0
init_values['init_sea'] = init_sea

return init_values

seasonality = self._seasonality

# init_values_partial = partial(init_values_callable, seasonality=seasonality)
# partialfunc does not work when passed to PyStan because PyStan uses
# inspect.getargspec(func) which seems to raise an exception with keyword-only args
# caused by using partialfunc
# lambda as an alternative workaround
if seasonality > 1:
init_values_callable = lambda: init_values_function(seasonality) # noqa
# lambda does not work in serialization in pickle
# callable object as an alternative workaround
if self._seasonality > 1:
init_values_callable = ETSInitializer(
self._seasonality
)
self._init_values = init_values_callable

def _set_static_data_attributes(self):
Expand Down
Loading

0 comments on commit 2b4aba9

Please sign in to comment.