Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update residual diagnostic plots and example notebook #758

Merged
merged 21 commits into from
Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions docs/tutorials/model_diagnostics.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/tutorials/pyro_basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -310,7 +310,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
"version": "3.7.7"
},
"toc": {
"base_numbering": 1,
Expand Down
374 changes: 374 additions & 0 deletions docs/tutorials/residual_diagnostic.ipynb

Large diffs are not rendered by default.

131 changes: 130 additions & 1 deletion orbit/diagnostics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
import math
import os
import pkg_resources
import statsmodels.api as sm
from scipy import stats

from ..constants.constants import PredictionKeys
from orbit.utils.general import is_empty_dataframe, is_ordered_datetime
from ..constants.constants import BacktestFitKeys
from ..constants.palette import PredictionPaletteClassic as PredPal
from orbit.constants import palette
from orbit.diagnostics.metrics import smape
from orbit.utils.plot import orbit_style_decorator

from ..exceptions import PlotException


import logging

logger = logging.getLogger("orbit")
Expand Down Expand Up @@ -715,3 +719,128 @@ def params_comparison_boxplot(
plt.title(title)

return ax


@orbit_style_decorator
def residual_diagnostic_plot(
    df,
    dist="norm",
    date_col="week",
    residual_col="residual",
    fitted_col="prediction",
    sparams=None,
):
    """Draw a 3 x 2 grid of diagnostic plots for model residuals.

    Parameters
    ----------
    df : pd.DataFrame
        data holding the date, residual and fitted-value columns
    dist : str; {"norm", "t-dist"}
        reference distribution used by the qq plot
    date_col : str
        column name of date
    residual_col : str
        column name of residual
    fitted_col : str
        column name of fitted value from model
    sparams : float or list
        distribution shape parameters forwarded to `scipy.stats.probplot`;
        only used when `dist` is "t-dist" (e.g. the degrees of freedom)

    Returns
    -------
    ax : array of matplotlib axes with shape (3, 2)

    Raises
    ------
    PlotException
        when `dist` is neither "norm" nor "t-dist"

    Notes
    -----
    1. residual by time
    2. residual vs fitted
    3. residual histogram with vertical line as mean
    4. residuals qq plot
    5. residual ACF
    6. residual PACF
    """
    # Validate before creating the figure: an unknown `dist` used to leave the
    # qq panel silently empty, and failing early avoids a half-drawn figure.
    supported_dists = ("norm", "t-dist")
    if dist not in supported_dists:
        raise PlotException(
            "Invalid dist '{}'; must be one of {}.".format(dist, supported_dists)
        )

    blue = palette.OrbitPalette.BLUE.value
    residuals = df[residual_col].values

    fig, ax = plt.subplots(3, 2, figsize=(15, 12))

    # plot 1 residual by time
    sns.lineplot(
        x=date_col,
        y=residual_col,
        data=df,
        ax=ax[0, 0],
        color=blue,
        alpha=0.8,
        label="residual",
    )
    ax[0, 0].set_title("Residual by Time")
    ax[0, 0].legend()

    # plot 2 residual vs fitted, with a zero reference line
    sns.scatterplot(
        x=fitted_col,
        y=residual_col,
        data=df,
        ax=ax[0, 1],
        color=blue,
        alpha=0.8,
        label="residual",
    )
    ax[0, 1].axhline(
        y=0,
        linestyle="--",
        color=palette.OrbitPalette.BLACK.value,
        alpha=0.5,
        label="0",
    )
    ax[0, 1].set_title("Residual vs Fitted")
    ax[0, 1].set_xlabel("fitted")
    ax[0, 1].legend()

    # plot 3 residual histogram with vertical line as mean
    # NOTE(review): sns.distplot is deprecated since seaborn 0.11; it is kept
    # because the project's minimum seaborn version is not visible here.
    sns.distplot(
        residuals,
        hist=True,
        kde=True,
        ax=ax[1, 0],
        color=blue,
        label="residual",
        hist_kws={
            "edgecolor": "white",
            "alpha": 0.5,
            "facecolor": blue,
        },
    )
    ax[1, 0].set_title("Residual Distribution")
    ax[1, 0].axvline(
        df[residual_col].mean(),
        color=palette.OrbitPalette.ORANGE.value,
        linestyle="--",
        alpha=0.9,
        label="residual mean",
    )
    ax[1, 0].set_ylabel("density")
    ax[1, 0].legend()

    # plot 4 residual qq plot (`dist` was validated above)
    if dist == "norm":
        _ = stats.probplot(residuals, dist="norm", plot=ax[1, 1])
    else:
        # t-dist qq-plot
        _ = stats.probplot(residuals, dist=stats.t, sparams=sparams, plot=ax[1, 1])

    # plot 5 residual ACF
    sm.graphics.tsa.plot_acf(
        residuals,
        ax=ax[2, 0],
        title="Residual ACF",
        color=blue,
    )
    ax[2, 0].set_xlabel("lag")
    ax[2, 0].set_ylabel("acf")

    # plot 6 residual PACF
    sm.graphics.tsa.plot_pacf(
        residuals,
        ax=ax[2, 1],
        title="Residual PACF",
        color=blue,
    )
    ax[2, 1].set_xlabel("lag")
    ax[2, 1].set_ylabel("pacf")
    fig.tight_layout()

    # Hand the axes back so callers can customize or save the figure.
    return ax
3 changes: 3 additions & 0 deletions orbit/models/dlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def DLT(
global_floor : float
Minimum value of global logistic trend. Default is set to 0.0. This value is used only when
`global_trend_option` = 'logistic'
forecast_horizon : int
forecast_horizon will be used only when users want to specify optimization forecast horizon > 1
estimator : string; {'stan-mcmc', 'stan-map'}
default to be 'stan-mcmc'.

Expand Down Expand Up @@ -123,6 +125,7 @@ def DLT(
global_floor=global_floor,
forecast_horizon=forecast_horizon,
)

if estimator == EstimatorsKeys.StanMAP.value:
dlt_forecaster = MAPForecaster(
model=dlt, estimator_type=StanEstimatorMAP, **kwargs
Expand Down
2 changes: 0 additions & 2 deletions orbit/template/dlt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import pandas as pd
from scipy.stats import nct
from statsmodels.api import OLS
from copy import deepcopy
import torch
from enum import Enum
Expand All @@ -19,7 +18,6 @@
from ..exceptions import IllegalArgument, ModelException, PredictionException
from .ets import ETSModel
from ..estimators.stan_estimator import StanEstimatorMCMC, StanEstimatorMAP
from ..estimators.pyro_estimator import PyroEstimatorSVI


class DataInputMapper(Enum):
Expand Down