Skip to content

Commit

Permalink
Fixes #43: Remove kwargs from base transformer. init (#56)
Browse files Browse the repository at this point in the history
* removed **kwargs arg from BaseTransformer.__init__

* added test for unexpected kwarg handling to test_BaseTransformer.py
added similar test to example child class test module
test_DataFrameMethodTransformer.py

* fix broken tests with incorrect kwargs
  • Loading branch information
davidhopkinson26 authored Jan 9, 2023
1 parent 77f65e1 commit 5f3ba58
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 18 deletions.
11 changes: 11 additions & 0 deletions tests/base/test_BaseTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,17 @@ def test_y_no_rows_error(self):

x.fit(X=df, y=pd.Series(name="b", dtype=object))

def test_unexpected_kwarg_error(self):

with pytest.raises(
TypeError,
match=re.escape(
"__init__() got an unexpected keyword argument 'unexpected_kwarg'"
),
):

BaseTransformer(columns="a", unexpected_kwarg="spanish inquisition")


class TestTransform(object):
"""Tests for BaseTransformer.transform()."""
Expand Down
18 changes: 18 additions & 0 deletions tests/base/test_DataFrameMethodTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tests.test_data as d
import pandas as pd
import numpy as np
import re

import tubular
from tubular.base import DataFrameMethodTransformer
Expand Down Expand Up @@ -167,6 +168,23 @@ def test_attributes_set(self):
msg="Attributes for DataFrameMethodTransformer set in init",
)

def test_unexpected_kwarg_error(self):

with pytest.raises(
TypeError,
match=re.escape(
"__init__() got an unexpected keyword argument 'unexpected_kwarg'"
),
):

DataFrameMethodTransformer(
new_column_name="a",
pd_method_name="sum",
columns=["b", "c"],
drop_original=True,
unexpected_kwarg="spanish inquisition",
)


class TestTransform(object):
"""Tests for DataFrameMethodTransformer.transform()."""
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_NearestMeanResponseImputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_arguments(self):
def test_class_methods(self):
"""Test that NearestMeanResponseImputer has fit and transform methods."""

x = NearestMeanResponseImputer(response_column="c", columns=None)
x = NearestMeanResponseImputer(columns=None)

ta.classes.test_object_method(obj=x, expected_method="fit", msg="fit")

Expand Down
4 changes: 2 additions & 2 deletions tests/nominal/test_GroupRareLevelsTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def test_learnt_values_weight(self):

df = d.create_df_6()

x = GroupRareLevelsTransformer(columns=["b"], cut_off_percent=0.3, weights="a")
x = GroupRareLevelsTransformer(columns=["b"], cut_off_percent=0.3, weight="a")

x.fit(df)

Expand All @@ -235,7 +235,7 @@ def test_learnt_values_weight_2(self):

df = d.create_df_6()

x = GroupRareLevelsTransformer(columns=["c"], cut_off_percent=0.2, weights="a")
x = GroupRareLevelsTransformer(columns=["c"], cut_off_percent=0.2, weight="a")

x.fit(df)

Expand Down
8 changes: 4 additions & 4 deletions tests/nominal/test_MeanResponseTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_arguments(self):
def test_class_methods(self):
"""Test that MeanResponseTransformer has fit and transform methods."""

x = MeanResponseTransformer(response_column="a")
x = MeanResponseTransformer()

ta.classes.test_object_method(obj=x, expected_method="fit", msg="fit")

Expand All @@ -40,7 +40,7 @@ def test_class_methods(self):
def test_inheritance(self):
"""Test that NominalToIntegerTransformer inherits from BaseNominalTransformer."""

x = MeanResponseTransformer(response_column="a")
x = MeanResponseTransformer()

ta.classes.assert_inheritance(x, tubular.nominal.BaseNominalTransformer)

Expand Down Expand Up @@ -127,7 +127,7 @@ def test_check_is_fitted_called(self, mocker):

expected_call_args = {0: {"args": (["global_mean"],), "kwargs": {}}}

x = MeanResponseTransformer(response_column="target")
x = MeanResponseTransformer()

x.fit(pd.DataFrame({"a": ["1", "2"]}), pd.Series([2, 3]))

Expand Down Expand Up @@ -612,7 +612,7 @@ def test_learnt_values_not_modified(self):
def test_expected_output(self, df, expected):
"""Test that the output is expected from transform."""

x = MeanResponseTransformer(response_column="a", columns=["b", "d", "f"])
x = MeanResponseTransformer(columns=["b", "d", "f"])

# set the impute values dict directly rather than fitting x on df so test works with helpers
x.mappings = {
Expand Down
4 changes: 2 additions & 2 deletions tests/nominal/test_OrdinalEncoderTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_arguments(self):
def test_class_methods(self):
"""Test that OrdinalEncoderTransformer has fit and transform methods."""

x = OrdinalEncoderTransformer(response_column="a")
x = OrdinalEncoderTransformer()

ta.classes.test_object_method(obj=x, expected_method="fit", msg="fit")

Expand All @@ -33,7 +33,7 @@ def test_class_methods(self):
def test_inheritance(self):
"""Test that NominalToIntegerTransformer inherits from BaseNominalTransformer."""

x = OrdinalEncoderTransformer(response_column="a")
x = OrdinalEncoderTransformer()

ta.classes.assert_inheritance(x, tubular.nominal.BaseNominalTransformer)
ta.classes.assert_inheritance(x, tubular.mapping.BaseMappingTransformMixin)
Expand Down
2 changes: 0 additions & 2 deletions tests/numeric/test_LogTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_base_type_error(self):
LogTransformer(
columns=["a"],
base="a",
new_column_name="b",
)

def test_base_not_strictly_positive_error(self):
Expand All @@ -46,7 +45,6 @@ def test_base_not_strictly_positive_error(self):
LogTransformer(
columns=["a"],
base=0,
new_column_name="b",
)

def test_class_methods(self):
Expand Down
14 changes: 11 additions & 3 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def ListOfTransformers():
imputers.MedianImputer(columns="a"),
imputers.MeanImputer(columns="a"),
imputers.ModeImputer(columns="a"),
imputers.NearestMeanResponseImputer(response_column="a"),
imputers.NearestMeanResponseImputer(columns="a"),
imputers.NullIndicator(columns="a"),
mapping.BaseMappingTransformer(mappings={"a": {1: 2, 3: 4}}),
mapping.BaseMappingTransformMixin(),
Expand All @@ -64,8 +64,8 @@ def ListOfTransformers():
nominal.BaseNominalTransformer(),
nominal.NominalToIntegerTransformer(columns="a"),
nominal.GroupRareLevelsTransformer(columns="a"),
nominal.MeanResponseTransformer(columns="a", response_column="b"),
nominal.OrdinalEncoderTransformer(columns="a", response_column="b"),
nominal.MeanResponseTransformer(columns="a"),
nominal.OrdinalEncoderTransformer(columns="a"),
nominal.OneHotEncodingTransformer(columns="a"),
numeric.LogTransformer(columns="a"),
numeric.CutTransformer(column="a", new_column_name="b"),
Expand Down Expand Up @@ -96,3 +96,11 @@ def test_clone(self, transformer):
"""

b.clone(transformer)

@pytest.mark.parametrize("transformer", ListOfTransformers())
def test_unexpected_kwarg(self, transformer):
"""
Test that transformer can be used in sklearn.base.clone function.
"""

b.clone(transformer)
5 changes: 1 addition & 4 deletions tubular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ class BaseTransformer(TransformerMixin, BaseEstimator):
verbose : bool, default = False
Should statements be printed when methods are run?
**kwds
Arbitrary keyword arguments.
Attributes
----------
columns : list or None
Expand All @@ -54,7 +51,7 @@ def classname(self):
"""Method that returns the name of the current class when called"""
return type(self).__name__

def __init__(self, columns=None, copy=True, verbose=False, **kwargs):
def __init__(self, columns=None, copy=True, verbose=False):

self.version_ = __version__

Expand Down

0 comments on commit 5f3ba58

Please sign in to comment.