diff --git a/eurybia/core/smartdrift.py b/eurybia/core/smartdrift.py index 897eddf..c465669 100644 --- a/eurybia/core/smartdrift.py +++ b/eurybia/core/smartdrift.py @@ -13,6 +13,7 @@ import catboost import pandas as pd +from pandas.api.types import is_datetime64_any_dtype as is_datetime from shapash.explainer.smart_explainer import SmartExplainer from sklearn.metrics import roc_auc_score from sklearn.model_selection import train_test_split @@ -22,7 +23,7 @@ from eurybia.utils.io import load_pickle, save_pickle from eurybia.utils.model_drift import catboost_hyperparameter_init, catboost_hyperparameter_type from eurybia.utils.statistical_tests import chisq_test, compute_js_divergence, ksmirnov_test -from eurybia.utils.utils import base_100 +from eurybia.utils.utils import base_100, convert_date_col_into_multiple_col logging.getLogger("papermill").setLevel(logging.WARNING) logging.getLogger("blib2to3").setLevel(logging.WARNING) @@ -248,6 +249,9 @@ def compile( if sample_size is not None: self.df_baseline = self._sampling(sampling, sample_size, self.df_baseline) self.df_current = self._sampling(sampling, sample_size, self.df_current) + + # Checking datasets + self._check_dataset(ignore_cols) # Consistency analysis pb_cols, err_mods = self._analyze_consistency(full_validation, ignore_cols) @@ -385,6 +389,40 @@ def generate_report( if rm_working_dir: shutil.rmtree(working_dir) + def _check_dataset(self, ignore_cols: list = list()): + """ + Method to check if datasets are correct before to be analysed and if + it's not, try to modify them and informs the user. In worse case raise + an error. + + Parameters + ---------- + full_validation : bool, optional (default: False) + If True, analyze consistency on modalities between columns + ignore_cols: list, optional + list of feature to ignore in compute + """ + + if len([column for column in self.df_current.columns if is_datetime(self.df_current[column])]) > 0: + if self.deployed_model is None: + for col in [column for column in self.df_current.columns if is_datetime(self.df_current[column])]: + print( + f"""Column {col} will be dropped and transformed in df_current by : {col}_year, {col}_month, {col}_day""" + ) + self.df_current = convert_date_col_into_multiple_col(self.df_current) + else: + raise TypeError("df_current have datetime column. You should drop it") + + if len([column for column in self.df_baseline.columns if is_datetime(self.df_baseline[column])]) > 0: + if self.deployed_model is None: + for col in [column for column in self.df_baseline.columns if is_datetime(self.df_baseline[column])]: + print( + f"""Column {col} will be dropped and transformed in df_baseline by : {col}_year, {col}_month, {col}_day""" + ) + self.df_baseline = convert_date_col_into_multiple_col(self.df_baseline) + else: + raise TypeError("df_baseline have datetime column. You should drop it") + def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()): """ method to analyse consistency between the 2 datasets, in terms of columns and modalities @@ -422,6 +460,7 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list() err_dtypes = [ c for c in common_cols if self.df_baseline.dtypes.map(str)[c] != self.df_current.dtypes.map(str)[c] ] + if len(err_dtypes) > 0: print( f"""The following variables have mismatching dtypes diff --git a/eurybia/utils/utils.py b/eurybia/utils/utils.py index 79953ca..2d7f693 100644 --- a/eurybia/utils/utils.py +++ b/eurybia/utils/utils.py @@ -4,6 +4,7 @@ from pathlib import Path import pandas as pd +from pandas.api.types import is_datetime64_any_dtype as is_datetime def convert_string_to_int_keys(input_dict: dict) -> dict: @@ -91,3 +92,33 @@ def round_to_k(x, k): return int(new_x) # Avoid the '.0' that can mislead the user that it may be a round number else: return new_x + + +def convert_date_col_into_multiple_col(df: pd.DataFrame) -> pd.DataFrame: + """ + Transform datetime column into multiple columns + - year + - month + - day + Drop datetime column + Parameters + ---------- + df: pd.Dataframe + input DataFrame with datetime columns + Returns + ------- + pd.Dataframe + DataFrame without datetime columns + """ + + date_col_list = [column for column in df.columns if is_datetime(df[column])] + + for col_date in date_col_list: + df[col_date + "_year"] = df[col_date].dt.year + df[col_date + "_month"] = df[col_date].dt.month + df[col_date + "_day"] = df[col_date].dt.day + + # droping original date column + df = df.drop(col_date, axis=1) + + return df diff --git a/tests/unit_tests/core/test_smartdrift.py b/tests/unit_tests/core/test_smartdrift.py index e0fed75..9f3f879 100644 --- a/tests/unit_tests/core/test_smartdrift.py +++ b/tests/unit_tests/core/test_smartdrift.py @@ -7,7 +7,9 @@ from pathlib import Path from unittest.mock import Mock, patch +import numpy as np import pandas as pd +import pytest import shapash from category_encoders import OrdinalEncoder from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor @@ -469,3 +471,47 @@ def test_define_style(self): assert sd.plot._style_dict["featimportance_colorscale"] == colors_dict["featimportance_colorscale"] assert sd.plot._style_dict["contrib_colorscale"] == colors_dict["contrib_colorscale"] # not testing the shapash.explainer.smart_explainer + + def test_datetime_column_transformation(self): + """ + Test if SmartDrift can automatically handle datatime columns + """ + + date_list = pd.date_range(start="01/01/2022", end="01/30/2022") + X1 = np.random.rand(len(date_list)) + X2 = np.random.rand(len(date_list)) + + df_current = pd.DataFrame(date_list, columns=["date"]) + df_current["col1"] = X1 + df_baseline = pd.DataFrame(date_list, columns=["date"]) + df_baseline["col1"] = X2 + + sd = SmartDrift(df_current=df_current, df_baseline=df_baseline) + sd.compile(full_validation=True) + # Should pass this step + auc = sd.auc + assert auc > 0 + + def test_datetime_column_model_error(self): + """ + Test if SmartDrift raised an error when their is datatime columns + and deployed_model is filled + """ + + date_list = pd.date_range(start="01/01/2022", end="01/30/2022") + X1 = np.random.rand(len(date_list)) + X2 = np.random.rand(len(date_list)) + + df_current = pd.DataFrame(date_list, columns=["date"]) + df_current["col1"] = X1 + df_baseline = pd.DataFrame(date_list, columns=["date"]) + df_baseline["col1"] = X2 + + # Random models + regressor = RandomForestRegressor(n_estimators=2).fit(df_baseline[["col1"]], df_baseline["col1"].ravel()) + + sd = SmartDrift(df_current=df_current, df_baseline=df_baseline, deployed_model=regressor) + + # Should raise an error + with pytest.raises(TypeError, match="df_current have datetime column. You should drop it"): + sd.compile(full_validation=True)