Skip to content

Commit

Permalink
2.1.0 Change fit_params into a builder function to better support cal…
Browse files Browse the repository at this point in the history
…lback functions (#14)

Without a builder function for `fit_params`, the same LightGBM early-stopping
callback object is reused across runs, which is wrong.
  • Loading branch information
kingychiu authored Dec 22, 2023
1 parent 16602d4 commit 5d8a012
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 18 deletions.
2 changes: 1 addition & 1 deletion benchmarks/run_tabular_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def write_report(reports: List):
importance_dfs = compute(
model_cls=model_cls,
model_cls_params=model_cls_params,
model_fit_params={},
model_fit_params=lambda _: {},
X=X_train,
y=y_train,
num_actual_runs=num_actual_runs,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "target-permutation-importances"
version = "2.0.0"
version = "2.1.0"
description = "Compute (Target) Permutation Importances of a machine learning model"
authors = [{name = "Anthony Chiu", email = "kingychiu@gmail.com"}]
maintainers = [{name = "Anthony Chiu", email = "kingychiu@gmail.com"}]
Expand All @@ -24,7 +24,7 @@ classifiers = [

[tool.poetry]
name = "target-permutation-importances"
version = "2.0.0"
version = "2.1.0"
description = "Compute (Target) Permutation Importances of a machine learning model"
authors = ["Anthony Chiu <kingychiu@gmail.com>"]
maintainers = ["Anthony Chiu <kingychiu@gmail.com>"]
Expand Down
13 changes: 10 additions & 3 deletions target_permutation_importances/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from target_permutation_importances.typing import (
ModelBuilderType,
ModelFitParamsBuilderType,
ModelFitterType,
ModelImportanceGetter,
PermutationImportanceCalculatorType,
Expand Down Expand Up @@ -328,7 +329,7 @@ def _get_model_importances_attr(model: Any):
def compute(
model_cls: Any,
model_cls_params: Dict,
model_fit_params: Dict,
model_fit_params: Union[ModelFitParamsBuilderType, Dict],
X: XType,
y: YType,
num_actual_runs: PositiveInt = 2,
Expand All @@ -344,7 +345,7 @@ def compute(
Args:
model_cls: The constructor/class of the model.
model_cls_params: The parameters to pass to the model constructor.
model_fit_params: The parameters to pass to the model fit method.
model_fit_params: A dict, or a function that takes the feature column names (None when X is not a DataFrame) and returns the parameters to pass to the model fit method.
X: The input data.
y: The target vector.
num_actual_runs: Number of actual runs. Defaults to 2.
Expand Down Expand Up @@ -423,7 +424,13 @@ def _model_builder(is_random_run: bool, run_idx: int) -> Any:
return model_cls(**_model_cls_params)

def _model_fitter(model: Any, X: XType, y: YType) -> Any:
_model_fit_params = model_fit_params.copy()
if isinstance(model_fit_params, dict): # pragma: no cover
_model_fit_params = model_fit_params.copy()
else:
# Assume it is a function
_model_fit_params = model_fit_params(
list(X.columns) if isinstance(X, pd.DataFrame) else None,
)
if "Cat" in str(model.__class__):
_model_fit_params["verbose"] = False
return model.fit(X, y, **_model_fit_params)
Expand Down
2 changes: 1 addition & 1 deletion target_permutation_importances/sklearn_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def fit(
result = compute(
model_cls=self.model_cls,
model_cls_params=self.model_cls_params,
model_fit_params=fit_params,
model_fit_params=lambda _: fit_params,
X=X,
y=y,
num_actual_runs=self.num_actual_runs,
Expand Down
8 changes: 7 additions & 1 deletion target_permutation_importances/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
from beartype import vale
from beartype.typing import Any, List, Union, runtime_checkable
from beartype.typing import Any, List, Optional, Union, runtime_checkable
from typing_extensions import Annotated, Protocol

XType = Union[np.ndarray, pd.DataFrame]
Expand Down Expand Up @@ -41,6 +41,12 @@ def __call__(self, is_random_run: bool, run_idx: int) -> YType:
...


@runtime_checkable
class ModelFitParamsBuilderType(Protocol): # pragma: no cover
def __call__(self, feature_columns: Optional[List[str]]) -> dict:
...


@runtime_checkable
class ModelBuilderType(Protocol): # pragma: no cover
"""
Expand Down
20 changes: 10 additions & 10 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_compute_binary_classification(model_cls, imp_func, xtype):
result_df = compute(
model_cls=model_cls[0],
model_cls_params=model_cls[1],
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=imp_func,
X=X,
y=data.target,
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_compute_multi_class_classification(model_cls, imp_func, xtype):
result_df = compute(
model_cls=model_cls[0],
model_cls_params=model_cls[1],
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=imp_func,
X=X,
y=data.target,
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_compute_multi_label_classification(model_cls, imp_func, xtype):
result_df = compute(
model_cls=model_cls[0],
model_cls_params=model_cls_params,
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=imp_func,
X=X,
y=y,
Expand Down Expand Up @@ -185,7 +185,7 @@ def test_compute_multi_label_classification_with_MultiOutputClassifier(
model_cls_params={
"estimator": model_cls[0](**model_cls[1]),
},
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=imp_func,
X=X,
y=y,
Expand Down Expand Up @@ -216,7 +216,7 @@ def test_compute_regression(model_cls, imp_func, xtype):
result_df = compute(
model_cls=model_cls[0],
model_cls_params=model_cls[1],
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=imp_func,
X=X,
y=data.target,
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_compute_multi_target_regression_with_MultiOutputRegressor(
model_cls_params={
"estimator": model_cls[0](**model_cls[1]),
},
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=imp_func,
X=X,
y=y,
Expand All @@ -282,7 +282,7 @@ def test_compute_with_multiple_importance_functions():
result_dfs = compute(
model_cls=RandomForestClassifier,
model_cls_params={"n_estimators": 2, "n_jobs": 1},
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=[
compute_permutation_importance_by_subtraction,
compute_permutation_importance_by_division,
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_invalid_compute():
compute(
model_cls=RandomForestClassifier,
model_cls_params={},
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=compute_permutation_importance_by_subtraction,
X=1,
y=data.target,
Expand All @@ -382,7 +382,7 @@ def test_invalid_compute():
compute(
model_cls=RandomForestClassifier,
model_cls_params={},
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=compute_permutation_importance_by_subtraction,
X=Xpd,
y=1,
Expand All @@ -393,7 +393,7 @@ def test_invalid_compute():
compute(
model_cls=RandomForestClassifier,
model_cls_params={},
model_fit_params={},
model_fit_params=lambda _: {},
permutation_importance_calculator=compute_permutation_importance_by_subtraction,
X=Xpd,
y=data.target,
Expand Down

0 comments on commit 5d8a012

Please sign in to comment.