Skip to content

Commit

Permalink
193 add isolation forest conversion to treeexplainer (#289)
Browse files Browse the repository at this point in the history
* Add basic implementation Isolation Forest

* Update conversion isoforest

* Update both conversion versions

* Update test iso tree and own scoring

* Update validation method

* Comment out benchmark for import route issue

* started check

* Add scaling

* Use updated values in convert_isolation_tree_shap_isotree

* ran check

* Disable value_updating

* Clean up isoforest support

* Reactivate benchmarks

* Lint and format

* Add tests for isolation forest conversion

* adds check against shap implementation

* deleted check_shap.py

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
  • Loading branch information
r-visser and mmschlk authored Dec 17, 2024
1 parent 1bb3ef7 commit 3220977
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 8 deletions.
95 changes: 95 additions & 0 deletions shapiq/explainer/tree/conversion/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import numpy as np
from sklearn.ensemble._iforest import _average_path_length

from shapiq.utils import safe_isinstance
from shapiq.utils.types import Model
Expand Down Expand Up @@ -74,3 +75,97 @@ def convert_sklearn_tree(
empty_prediction=None, # compute empty prediction later
original_output_type=output_type,
)


def average_path_length(isolation_forest):
    """Return the average path length for the sub-sample size of an isolation forest.

    Args:
        isolation_forest: The isolation forest model; its ``_max_samples`` attribute is read.

    Returns:
        The result of ``_average_path_length`` evaluated on ``[isolation_forest._max_samples]``.
    """
    # NOTE: _average_path_length is equivalent to equation 1 in the Isolation Forest paper
    # (Liu et al., 2008)
    sample_sizes = [isolation_forest._max_samples]
    return _average_path_length(sample_sizes)


def convert_sklearn_isolation_forest(
    tree_model: Model,
) -> list[TreeModel]:
    """Transforms a scikit-learn isolation forest to the format used by shapiq.

    Every estimator of the forest is converted on its own with ``convert_isolation_tree``.
    The values of each tree are scaled by ``1 / n_estimators``.

    Args:
        tree_model: The scikit-learn isolation forest model to convert. Its ``estimators_``
            and ``estimators_features_`` attributes are read.

    Returns:
        The converted isolation forest model as a list with one ``TreeModel`` per estimator.
    """
    estimators = tree_model.estimators_
    per_tree_scaling = 1.0 / len(estimators)

    converted_trees = []
    for estimator, feature_subset in zip(estimators, tree_model.estimators_features_):
        converted_trees.append(
            convert_isolation_tree(estimator, feature_subset, scaling=per_tree_scaling)
        )
    return converted_trees


def convert_isolation_tree(
    tree_model: Model,
    tree_features,
    class_label: Optional[int] = None,
    scaling: float = 1.0,
    average_path_length: float = 1.0,  # TODO fix default value
) -> TreeModel:
    """Convert a single tree of a scikit-learn isolation forest to the format used by shapiq.

    The node values stored in the fitted tree are not used. They are recomputed by
    ``isotree_value_traversal`` and afterwards multiplied by ``scaling``.

    Args:
        tree_model: The tree estimator to convert. Its ``tree_`` attribute is read.
        tree_features: The feature indices belonging to this tree. Forwarded to
            ``isotree_value_traversal``, which uses them to re-number the node features.
        class_label: Not used by this conversion.
        scaling: The scaling factor for the tree values.
        average_path_length: Not used by this conversion.

    Returns:
        The converted tree model.
    """
    output_type = "raw"
    sklearn_tree = tree_model.tree_
    # the traversal is run unscaled; the scaling is applied once on its result below
    features_updated, values_updated = isotree_value_traversal(
        sklearn_tree, tree_features, normalize=False, scaling=1.0
    )
    values_updated = (values_updated * scaling).flatten()

    return TreeModel(
        children_left=sklearn_tree.children_left,
        children_right=sklearn_tree.children_right,
        features=features_updated,
        thresholds=sklearn_tree.threshold,
        values=values_updated,
        node_sample_weight=sklearn_tree.weighted_n_node_samples,
        empty_prediction=None,  # compute empty prediction later
        original_output_type=output_type,
    )


def isotree_value_traversal(
    tree, tree_features, normalize=False, scaling=1.0, data=None, data_missing=None
):
    """Recompute the node values of an isolation tree and re-number its features.

    For a ``sklearn.tree._tree.Tree`` every leaf gets the value
    ``depth + _average_path_length(n_node_samples)``. Every inner node gets the mean of the
    leaf values below it, weighted by the number of samples in each leaf. The values are then
    optionally normalized, multiplied by ``scaling``, and the split features are mapped
    through ``tree_features``.

    Args:
        tree: The tree to traverse. The attributes ``feature``, ``value``, ``children_left``,
            ``children_right`` and ``n_node_samples`` are read.
        tree_features: Lookup table for the feature indices; a split on feature ``j`` is
            re-numbered to ``tree_features[j]``.
        normalize: Whether to divide the values by their sum along axis 1.
        scaling: Factor the recomputed values are multiplied with.
        data: Not used.
        data_missing: Not used.

    Returns:
        A tuple ``(features, corrected_values)``. If ``tree`` is not a
        ``sklearn.tree._tree.Tree``, these are unmodified copies of ``tree.feature`` and
        ``tree.value``.
    """
    features = tree.feature.copy()
    corrected_values = tree.value.copy()
    if safe_isinstance(tree, "sklearn.tree._tree.Tree"):

        def _recalculate_value(tree, i, level):
            """Set the value of node ``i`` and return that value times the node's sample count."""
            n_samples = tree.n_node_samples[i]
            if tree.children_left[i] == -1 and tree.children_right[i] == -1:
                # leaf: depth reached so far plus the average path length for its samples
                value = level + _average_path_length(np.array([n_samples]))[0]
                corrected_values[i, 0] = value
                return value * n_samples
            value_left = _recalculate_value(tree, tree.children_left[i], level + 1)
            value_right = _recalculate_value(tree, tree.children_right[i], level + 1)
            corrected_values[i, 0] = (value_left + value_right) / n_samples
            return value_left + value_right

        _recalculate_value(tree, 0, 0)
        # NOTE(review): in this file the function is only called with normalize=False;
        # verify the array shape handling before relying on normalize=True.
        if normalize:
            corrected_values = (corrected_values.T / corrected_values.sum(1)).T
        corrected_values = corrected_values * scaling
        # re-number the features if each tree gets a different set of features; only split
        # nodes are looked up, because indexing tree_features with the negative entries of
        # the non-split nodes can be out of bounds for short feature lists
        is_split = features >= 0
        features[is_split] = np.asarray(tree_features)[features[is_split]]
    return features, corrected_values
3 changes: 2 additions & 1 deletion shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TreeExplainer(Explainer):

def __init__(
self,
model: Union[dict, TreeModel, Any],
model: Union[dict, TreeModel, list, Any],
max_order: int = 2,
min_order: int = 1,
index: str = "k-SII",
Expand All @@ -61,6 +61,7 @@ def __init__(
# validate and parse model
validated_model = validate_tree_model(model, class_label=class_index)
self._trees: list[TreeModel] = copy.deepcopy(validated_model)
# TODO: the trees are wrapped in a list here, but validation also builds a list and then converts it back into a single element if the list has length 1
if not isinstance(self._trees, list):
self._trees = [self._trees]
self._n_trees = len(self._trees)
Expand Down
19 changes: 16 additions & 3 deletions shapiq/explainer/tree/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from .base import TreeModel
from .conversion.lightgbm import convert_lightgbm_booster
from .conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree
from .conversion.sklearn import (
convert_sklearn_forest,
convert_sklearn_isolation_forest,
convert_sklearn_tree,
)
from .conversion.xgboost import convert_xgboost_booster

SUPPORTED_MODELS = {
Expand All @@ -20,6 +24,8 @@
"sklearn.ensemble._forest.ExtraTreesClassifier",
"sklearn.ensemble.RandomForestRegressor",
"sklearn.ensemble._forest.RandomForestRegressor",
"sklearn.ensemble.IsolationForest",
"sklearn.ensemble._iforest.IsolationForest",
"lightgbm.sklearn.LGBMRegressor",
"lightgbm.sklearn.LGBMClassifier",
"lightgbm.basic.Booster",
Expand All @@ -42,8 +48,11 @@ def validate_tree_model(
# tree model (is already in the correct format)
if type(model).__name__ == "TreeModel":
tree_model = model
elif isinstance(model, list) and all([type(m).__name__ == "TreeModel" for m in model]):
tree_model = model
# direct return if list of tree models
elif type(model).__name__ == "list":
# check if all elements are TreeModel
if all([type(tree).__name__ == "TreeModel" for tree in model]):
tree_model = model
# dict as model is parsed to TreeModel (the dict needs to have the correct format and names)
elif type(model).__name__ == "dict":
tree_model = TreeModel(**model)
Expand All @@ -66,6 +75,10 @@ def validate_tree_model(
or safe_isinstance(model, "sklearn.ensemble._forest.ExtraTreesClassifier")
):
tree_model = convert_sklearn_forest(model, class_label=class_label)
elif safe_isinstance(model, "sklearn.ensemble.IsolationForest") or safe_isinstance(
model, "sklearn.ensemble._iforest.IsolationForest"
):
tree_model = convert_sklearn_isolation_forest(model)
elif safe_isinstance(model, "lightgbm.sklearn.LGBMRegressor") or safe_isinstance(
model, "lightgbm.sklearn.LGBMClassifier"
):
Expand Down
2 changes: 1 addition & 1 deletion shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def plot_force(
Returns:
The force plot as a matplotlib figure (if show is ``False``).
"""
from shapiq import force_plot
from .plot import force_plot

return force_plot(
self,
Expand Down
21 changes: 20 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from PIL import Image
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.ensemble import IsolationForest, RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

Expand Down Expand Up @@ -142,6 +142,25 @@ def rf_clf_model() -> RandomForestClassifier:
return model


# Isolation forest model
@pytest.fixture
def if_clf_model() -> IsolationForest:
    """Return a small isolation forest fitted on two clusters plus uniform outliers."""
    n_samples, n_outliers = 120, 40
    rng = np.random.RandomState(0)

    # the order of the draws from ``rng`` fixes the data set, so it must stay as it is
    mixing = np.array([[0.5, -0.1], [0.7, 0.4]])
    general_cluster = 0.4 * rng.randn(n_samples, 2) @ mixing + np.array([2, 2])
    spherical_cluster = 0.3 * rng.randn(n_samples, 2) + np.array([-2, -2])
    uniform_outliers = rng.uniform(low=-4, high=4, size=(n_outliers, 2))
    X = np.concatenate([general_cluster, spherical_cluster, uniform_outliers])

    # label both clusters with 1 and the outliers with -1
    inlier_labels = np.ones((2 * n_samples), dtype=int)
    outlier_labels = -np.ones((n_outliers), dtype=int)
    y = np.concatenate([inlier_labels, outlier_labels])

    model = IsolationForest(random_state=42, n_estimators=3)
    model.fit(X, y)
    return model


@pytest.fixture
def xgb_reg_model():
"""Return a simple xgboost regression model."""
Expand Down
26 changes: 26 additions & 0 deletions tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,29 @@ def test_xgboost_shap_error(xgb_clf_model, background_clf_data):

# now the values surprisingly are the same
assert np.allclose(sv_shap, sv_shapiq_rounded_values, rtol=1e-5)


def test_iso_forest_shap(if_clf_model):
    """Compares shapiq's Shapley values for an isolation forest with stored SHAP results."""
    x_explain = np.array([0.125, 0.05])

    # The reference numbers below come from the SHAP implementation and can be
    # reproduced with the following code:
    #   import shap
    #   model_copy = copy.deepcopy(if_clf_model)
    #   explainer_shap = shap.TreeExplainer(model=model_copy)
    #   baseline_shap = float(explainer_shap.expected_value)
    #   sv_shap = explainer_shap.shap_values(x_explain)
    #   print(sv_shap)
    #   print(baseline_shap)
    expected_values = np.array([-2.34951688, -4.55545493])
    expected_baseline = 12.238305148044713

    # compute with shapiq
    explainer = TreeExplainer(model=if_clf_model, max_order=1, index="SV")
    explanation = explainer.explain(x=x_explain)
    first_order_values = explanation.get_n_order_values(1)
    baseline = explanation.baseline_value

    assert expected_baseline == pytest.approx(baseline, rel=1e-6)
    assert np.allclose(expected_values, first_order_values, rtol=1e-5)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from shapiq.explainer.tree.base import TreeModel
from shapiq.explainer.tree.conversion.edges import create_edge_tree
from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree
from shapiq.explainer.tree.conversion.sklearn import (
convert_sklearn_forest,
convert_sklearn_isolation_forest,
convert_sklearn_tree,
)
from shapiq.utils import safe_isinstance


Expand Down Expand Up @@ -123,3 +127,14 @@ def test_skleanr_rf_conversion(rf_clf_model, rf_reg_model):
assert isinstance(tree_model, list)
assert safe_isinstance(tree_model[0], tree_model_class_path_str)
assert tree_model[0].empty_prediction is not None


def test_sklearn_if_conversion(if_clf_model):
    """Checks that an isolation forest is converted into a list of ``TreeModel`` objects."""
    expected_class_paths = ["shapiq.explainer.tree.base.TreeModel"]

    # convert the isolation forest and inspect the first converted tree
    converted = convert_sklearn_isolation_forest(if_clf_model)
    assert isinstance(converted, list)
    first_tree = converted[0]
    assert safe_isinstance(first_tree, expected_class_paths)
    assert first_tree.empty_prediction is not None
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from shapiq.explainer.tree.validation import validate_tree_model


def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model, if_clf_model):
"""Test the validation of the model."""
class_path_str = ["shapiq.explainer.tree.base.TreeModel"]
# sklearn dt models are supported
Expand All @@ -20,6 +20,10 @@ def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)
tree_model = validate_tree_model(rf_reg_model)
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)
# sklearn isolation forest is supported
tree_model = validate_tree_model(if_clf_model)
for tree in tree_model:
assert safe_isinstance(tree, class_path_str)

Expand Down

0 comments on commit 3220977

Please sign in to comment.