From c7def0d5922e28ea65426a187135a902664153a1 Mon Sep 17 00:00:00 2001 From: baggiponte <57922983+baggiponte@users.noreply.github.com> Date: Sat, 8 Jun 2024 16:07:03 +0200 Subject: [PATCH] Refactor2024/unstable decorator (#233) --- functime/__init__.py | 4 ---- functime/_utils.py | 14 +++++++++++--- functime/feature_extractors.py | 4 ++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/functime/__init__.py b/functime/__init__.py index b379be72..8be6190b 100644 --- a/functime/__init__.py +++ b/functime/__init__.py @@ -2,8 +2,4 @@ __version__ = "0.9.5" -import logging - from functime.feature_extractors import FeatureExtractor # noqa: F401 - -logging.basicConfig(level=logging.INFO) diff --git a/functime/_utils.py b/functime/_utils.py index e6691d63..8b911b74 100644 --- a/functime/_utils.py +++ b/functime/_utils.py @@ -1,13 +1,21 @@ from __future__ import annotations import logging -from typing import Any, Callable +from functools import wraps +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, ParamSpec, TypeVar + + P = ParamSpec("P") + R = TypeVar("R") logger = logging.getLogger(__name__) -def UseAtOwnRisk(func: Callable) -> Any: - def wrapped_func(*args, **kwargs): +def warn_is_unstable(func: Callable[P, R]) -> Callable[P, R]: + @wraps(func) + def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> R: logger.warning( f"The function {func.__name__} is unstable and untested. Use at your own risk." ) diff --git a/functime/feature_extractors.py b/functime/feature_extractors.py index ddb16e97..39af2092 100644 --- a/functime/feature_extractors.py +++ b/functime/feature_extractors.py @@ -15,7 +15,7 @@ from scipy.spatial import KDTree from functime._functime_rust import rs_faer_lstsq1 -from functime._utils import UseAtOwnRisk +from functime._utils import warn_is_unstable from functime.type_aliases import DetrendMethod # from functime.feature_extractor import FeatureExtractor # noqa: F401 @@ -164,7 +164,7 @@ def approximate_entropy( ApEn = approximate_entropy -@UseAtOwnRisk +@warn_is_unstable def augmented_dickey_fuller(x: TIME_SERIES_T, n_lags: int) -> float: """ Calculates the Augmented Dickey-Fuller (ADF) test statistic. This only works for Series input right now.