diff --git a/kats/detectors/prophet_detector.py b/kats/detectors/prophet_detector.py index db60baade..1f7a3e6d6 100644 --- a/kats/detectors/prophet_detector.py +++ b/kats/detectors/prophet_detector.py @@ -138,10 +138,27 @@ class SeasonalityTypes(Enum): EMPTY_LIST: List[SeasonalityTypes] = [] +def to_seasonality(seasonality: Union[str, SeasonalityTypes]) -> SeasonalityTypes: + if isinstance(seasonality, str): + try: + return SeasonalityTypes[seasonality.upper()] + except KeyError: + raise ValueError( + f"Invalid seasonality type: {seasonality}. Valid types are: {list(SeasonalityTypes)}" + ) + elif isinstance(seasonality, SeasonalityTypes): + return seasonality + else: + raise ValueError( + f"Expected string or SeasonalityTypes, got {type(seasonality)} instead" + ) + + def seasonalities_to_dict( seasonalities: Union[ SeasonalityTypes, List[SeasonalityTypes], + List[str], Dict[SeasonalityTypes, Union[bool, str]], ] ) -> Dict[SeasonalityTypes, Union[bool, str]]: @@ -149,7 +166,9 @@ def seasonalities_to_dict( if isinstance(seasonalities, SeasonalityTypes): seasonalities = {seasonalities: True} elif isinstance(seasonalities, list): - seasonalities = {seasonality: True for seasonality in seasonalities} + seasonalities = { + to_seasonality(seasonality): True for seasonality in seasonalities + } elif seasonalities is None: seasonalities = {} return seasonalities @@ -253,6 +272,7 @@ def __init__( Union[ SeasonalityTypes, List[SeasonalityTypes], + List[str], Dict[SeasonalityTypes, Union[bool, str]], ] ] = EMPTY_LIST, diff --git a/kats/tests/detectors/test_prophet_detector.py b/kats/tests/detectors/test_prophet_detector.py index 0374684c4..bab8bf3f7 100644 --- a/kats/tests/detectors/test_prophet_detector.py +++ b/kats/tests/detectors/test_prophet_detector.py @@ -5,6 +5,7 @@ import random from datetime import timedelta +from typing import Union from unittest import TestCase import numpy as np @@ -17,6 +18,7 @@ ProphetScoreFunction, ProphetTrendDetectorModel, SeasonalityTypes, + to_seasonality, ) from kats.utils.simulator import Simulator from parameterized.parameterized import parameterized @@ -872,3 +874,21 @@ def test_pmm_use_case(self) -> None: self.assertEqual( response_wo_historical_data.scores.value.shape, hist_ts.value.shape ) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator `parameter... + @parameterized.expand( + [ + ("day", SeasonalityTypes.DAY), + ("week", SeasonalityTypes.WEEK), + ("weekend", SeasonalityTypes.WEEKEND), + ("year", SeasonalityTypes.YEAR), + (SeasonalityTypes.DAY, SeasonalityTypes.DAY), + (SeasonalityTypes.WEEK, SeasonalityTypes.WEEK), + (SeasonalityTypes.WEEKEND, SeasonalityTypes.WEEKEND), + (SeasonalityTypes.YEAR, SeasonalityTypes.YEAR), + ] + ) + def test_to_seasonality( + self, actual: Union[str, SeasonalityTypes], expected: SeasonalityTypes + ) -> None: + self.assertEqual(to_seasonality(actual), expected)