Skip to content

Commit

Permalink
add internal script for validate_data and sklearn version constant, update gnb and gwnb scripts accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
msamsami committed Dec 25, 2024
1 parent 53ee486 commit b897d75
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
19 changes: 19 additions & 0 deletions wnb/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any

import sklearn
from packaging import version
from sklearn.utils import check_array

__all__ = ["SKLEARN_V1_6_OR_LATER", "validate_data"]


# True when the installed scikit-learn is version 1.6 or newer. This module
# uses the flag to choose between scikit-learn's own ``validate_data`` and a
# local fallback built on ``check_array``.
# NOTE(review): under PEP 440 ordering a pre-release such as "1.6.0rc1" sorts
# below "1.6", so such builds evaluate to False here -- confirm that is intended.
SKLEARN_V1_6_OR_LATER = version.parse(sklearn.__version__) >= version.parse("1.6")


if not SKLEARN_V1_6_OR_LATER:

    def validate_data(estimator, X, **kwargs: Any):
        """Fallback ``validate_data`` for scikit-learn releases older than 1.6.

        Parameters
        ----------
        estimator
            Forwarded to ``check_array`` as its ``estimator`` argument.
        X
            The input to validate.
        **kwargs
            Extra keyword arguments passed through to ``check_array``. A
            ``reset`` keyword, if given, is silently discarded (presumably
            because ``check_array`` does not accept it -- verify against the
            scikit-learn API).

        Returns
        -------
        Whatever ``check_array`` returns for ``X``.
        """
        # Forward everything except ``reset`` to ``check_array``.
        check_params = {key: value for key, value in kwargs.items() if key != "reset"}
        return check_array(X, estimator=estimator, **check_params)

else:
    # scikit-learn >= 1.6 ships its own public helper; re-export it as-is.
    from sklearn.utils.validation import validate_data
15 changes: 3 additions & 12 deletions wnb/gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,13 @@

import numpy as np
import pandas as pd
import sklearn
from packaging import version
from scipy.special import logsumexp
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import DataConversionWarning
from sklearn.utils import as_float_array, check_array
from sklearn.utils import as_float_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted

if version.parse(sklearn.__version__) >= version.parse("1.6"):
from sklearn.utils.validation import validate_data
else:

def validate_data(estimator, X, **kwargs):
return check_array(X, estimator=estimator, **kwargs)


if sys.version_info >= (3, 11):
from typing import Self
else:
Expand All @@ -34,6 +24,7 @@ def validate_data(estimator, X, **kwargs):
from wnb.stats.base import DistMixin
from wnb.stats.typing import DistributionLike

from ._utils import SKLEARN_V1_6_OR_LATER, validate_data
from .typing import ArrayLike, Float, MatrixLike

__all__ = ["GeneralNB"]
Expand Down Expand Up @@ -93,7 +84,7 @@ def __init__(
self.distributions = distributions
self.alpha = alpha

if version.parse(sklearn.__version__) >= version.parse("1.6"):
if SKLEARN_V1_6_OR_LATER:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
Expand Down
17 changes: 4 additions & 13 deletions wnb/gwnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,20 @@

import numpy as np
import pandas as pd
import sklearn
from packaging import version
from scipy.special import logsumexp
from scipy.stats import norm
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import DataConversionWarning
from sklearn.utils import as_float_array, check_array, deprecated
from sklearn.utils import as_float_array, deprecated
from sklearn.utils.multiclass import check_classification_targets, type_of_target
from sklearn.utils.validation import check_is_fitted

if version.parse(sklearn.__version__) >= version.parse("1.6"):
from sklearn.utils.validation import validate_data
else:

def validate_data(estimator, X, **kwargs):
return check_array(X, estimator=estimator, **kwargs)


if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from ._utils import SKLEARN_V1_6_OR_LATER, validate_data
from .typing import ArrayLike, Float, Int, MatrixLike

__all__ = ["GaussianWNB"]
Expand Down Expand Up @@ -121,7 +112,7 @@ def __init__(
self.C = C
self.learning_hist = learning_hist

if version.parse(sklearn.__version__) >= version.parse("1.6"):
if SKLEARN_V1_6_OR_LATER:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
Expand All @@ -138,7 +129,7 @@ def _check_inputs(self, X, y) -> None:

# Check that the dataset has only two unique labels
if (y_type := type_of_target(y)) != "binary":
if version.parse(sklearn.__version__) >= version.parse("1.6"):
if SKLEARN_V1_6_OR_LATER:
msg = f"Only binary classification is supported. The type of the target is {y_type}."
else:
msg = "Unknown label type: non-binary"
Expand Down

0 comments on commit b897d75

Please sign in to comment.