Skip to content

Commit

Permalink
Merge pull request #29 from armgilles/feature/transf_date
Browse files Browse the repository at this point in the history
Feature/transf date
  • Loading branch information
ThomasBouche authored Sep 23, 2022
2 parents 51b1369 + 21dfe8d commit f60be26
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
41 changes: 40 additions & 1 deletion eurybia/core/smartdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import catboost
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype as is_datetime
from shapash.explainer.smart_explainer import SmartExplainer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
Expand All @@ -22,7 +23,7 @@
from eurybia.utils.io import load_pickle, save_pickle
from eurybia.utils.model_drift import catboost_hyperparameter_init, catboost_hyperparameter_type
from eurybia.utils.statistical_tests import chisq_test, compute_js_divergence, ksmirnov_test
from eurybia.utils.utils import base_100
from eurybia.utils.utils import base_100, convert_date_col_into_multiple_col

logging.getLogger("papermill").setLevel(logging.WARNING)
logging.getLogger("blib2to3").setLevel(logging.WARNING)
Expand Down Expand Up @@ -248,6 +249,9 @@ def compile(
if sample_size is not None:
self.df_baseline = self._sampling(sampling, sample_size, self.df_baseline)
self.df_current = self._sampling(sampling, sample_size, self.df_current)

# Checking datasets
self._check_dataset(ignore_cols)
# Consistency analysis
pb_cols, err_mods = self._analyze_consistency(full_validation, ignore_cols)

Expand Down Expand Up @@ -385,6 +389,40 @@ def generate_report(
if rm_working_dir:
shutil.rmtree(working_dir)

def _check_dataset(self, ignore_cols: list = None):
    """
    Check that df_current and df_baseline can be analysed and, when
    possible, fix them while informing the user.

    For each dataset (df_current first, then df_baseline), the columns
    with a datetime dtype are looked up:

    - if there is none, the dataset is left untouched;
    - if ``self.deployed_model`` is None, a message is printed for every
      datetime column and the dataset is replaced by the output of
      ``convert_date_col_into_multiple_col``;
    - otherwise an error is raised.

    Parameters
    ----------
    ignore_cols: list, optional
        list of feature to ignore in compute. Not used by this check; it is
        accepted because ``compile`` passes it positionally.

    Raises
    ------
    TypeError
        If a dataset holds at least one datetime column while
        ``self.deployed_model`` is not None.
    """
    # Same rule for both datasets; df_current is handled first so that its
    # error takes precedence when both datasets contain datetime columns.
    for dataset_name in ("df_current", "df_baseline"):
        dataset = getattr(self, dataset_name)
        datetime_cols = [column for column in dataset.columns if is_datetime(dataset[column])]
        if not datetime_cols:
            continue

        if self.deployed_model is not None:
            raise TypeError(f"{dataset_name} have datetime column. You should drop it")

        for col in datetime_cols:
            print(
                f"""Column {col} will be dropped and transformed in {dataset_name} by : {col}_year, {col}_month, {col}_day"""
            )
        setattr(self, dataset_name, convert_date_col_into_multiple_col(dataset))

def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()):
"""
method to analyse consistency between the 2 datasets, in terms of columns and modalities
Expand Down Expand Up @@ -422,6 +460,7 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()
err_dtypes = [
c for c in common_cols if self.df_baseline.dtypes.map(str)[c] != self.df_current.dtypes.map(str)[c]
]

if len(err_dtypes) > 0:
print(
f"""The following variables have mismatching dtypes
Expand Down
31 changes: 31 additions & 0 deletions eurybia/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import pandas as pd
from pandas.api.types import is_datetime64_any_dtype as is_datetime


def convert_string_to_int_keys(input_dict: dict) -> dict:
Expand Down Expand Up @@ -91,3 +92,33 @@ def round_to_k(x, k):
return int(new_x) # Avoid the '.0' that can mislead the user that it may be a round number
else:
return new_x


def convert_date_col_into_multiple_col(df: pd.DataFrame) -> pd.DataFrame:
    """
    Transform every datetime column into multiple columns
        - <column>_year
        - <column>_month
        - <column>_day
    and drop the original datetime column.

    The input DataFrame is never modified: the derived columns are only
    added to the returned DataFrame.

    Parameters
    ----------
    df: pd.Dataframe
        input DataFrame with datetime columns
    Returns
    -------
    pd.Dataframe
        DataFrame without datetime columns. When the input has no datetime
        column, the input object itself is returned.
    """

    date_col_list = [column for column in df.columns if is_datetime(df[column])]
    if not date_col_list:
        return df

    # Collect the derived columns separately instead of writing them into
    # `df`, so the caller's DataFrame is left untouched.
    date_parts = {}
    for col_date in date_col_list:
        accessor = df[col_date].dt
        date_parts[col_date + "_year"] = accessor.year
        date_parts[col_date + "_month"] = accessor.month
        date_parts[col_date + "_day"] = accessor.day

    # Dropping the original date columns first keeps the resulting column
    # order: remaining columns, then year/month/day for each date column.
    return df.drop(columns=date_col_list).assign(**date_parts)
46 changes: 46 additions & 0 deletions tests/unit_tests/core/test_smartdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from pathlib import Path
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest
import shapash
from category_encoders import OrdinalEncoder
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
Expand Down Expand Up @@ -469,3 +471,47 @@ def test_define_style(self):
assert sd.plot._style_dict["featimportance_colorscale"] == colors_dict["featimportance_colorscale"]
assert sd.plot._style_dict["contrib_colorscale"] == colors_dict["contrib_colorscale"]
# not testing the shapash.explainer.smart_explainer

def test_datetime_column_transformation(self):
    """
    Check that SmartDrift compiles on datasets holding a datetime column
    when no deployed model is given.
    """
    days = pd.date_range(start="01/01/2022", end="01/30/2022")
    n_rows = len(days)
    current_values = np.random.rand(n_rows)
    baseline_values = np.random.rand(n_rows)

    df_current = pd.DataFrame({"date": days, "col1": current_values})
    df_baseline = pd.DataFrame({"date": days, "col1": baseline_values})

    sd = SmartDrift(df_current=df_current, df_baseline=df_baseline)
    # Compilation must go through despite the datetime column
    sd.compile(full_validation=True)
    assert sd.auc > 0

def test_datetime_column_model_error(self):
    """
    Test that SmartDrift raises an error when there are datetime columns
    and deployed_model is filled
    """

    date_list = pd.date_range(start="01/01/2022", end="01/30/2022")
    X1 = np.random.rand(len(date_list))
    X2 = np.random.rand(len(date_list))

    df_current = pd.DataFrame(date_list, columns=["date"])
    df_current["col1"] = X1
    df_baseline = pd.DataFrame(date_list, columns=["date"])
    df_baseline["col1"] = X2

    # Random model: only needs to exist so that deployed_model is not None.
    # Series.to_numpy() is used for the target because Series.ravel() is deprecated.
    regressor = RandomForestRegressor(n_estimators=2).fit(df_baseline[["col1"]], df_baseline["col1"].to_numpy())

    sd = SmartDrift(df_current=df_current, df_baseline=df_baseline, deployed_model=regressor)

    # Should raise an error
    with pytest.raises(TypeError, match="df_current have datetime column. You should drop it"):
        sd.compile(full_validation=True)

0 comments on commit f60be26

Please sign in to comment.