From 5c3c21826f5cb6cbdd5423db9156841ae942ad31 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 21 Mar 2024 16:55:53 +0100 Subject: [PATCH 1/6] adds logit (log_odds) support for TreeExplainer and closes #56 --- shapiq/explainer/tree/base.py | 80 ++++++++++++++++--- shapiq/explainer/tree/conversion/sklearn.py | 20 ++--- shapiq/explainer/tree/explainer.py | 6 +- shapiq/explainer/tree/treeshapiq.py | 4 +- shapiq/explainer/tree/validation.py | 41 ++++++++-- tests/conftest.py | 22 +++++ .../test_tree_explainer.py | 39 ++++++++- .../test_tree_explainer_validate.py | 56 +++++++++++-- .../test_tree_treeshapiq.py | 3 +- 9 files changed, 227 insertions(+), 44 deletions(-) diff --git a/shapiq/explainer/tree/base.py b/shapiq/explainer/tree/base.py index 1d3afa56..2a2e8d7d 100644 --- a/shapiq/explainer/tree/base.py +++ b/shapiq/explainer/tree/base.py @@ -27,6 +27,26 @@ class TreeModel: the empty prediction is computed from the leaf values and the sample weights. leaf_mask: The boolean mask of the leaf nodes in a tree. The default value is None. Then the leaf mask is computed from the children left and right arrays. + n_features_in_tree: The number of features in the tree model. The default value is None. + Then the number of features in the tree model is computed from the unique feature + indices in the features array. + max_feature_id: The maximum feature index in the tree model. The default value is None. Then + the maximum feature index in the tree model is computed from the features array. + feature_ids: The feature indices of the decision nodes in the tree model. The default value + is None. Then the feature indices of the decision nodes in the tree model are computed + from the unique feature indices in the features array. + root_node_id: The root node id of the tree model. The default value is None. Then the root + node id of the tree model is set to 0. + n_nodes: The number of nodes in the tree model. The default value is None. Then the number + of nodes in the tree model is computed from the children left array. + nodes: The node ids of the tree model. The default value is None. Then the node ids of the + tree model are computed from the number of nodes in the tree model. + feature_map_original_internal: A mapping of feature indices from the original feature + indices (as in the model) to the internal feature indices (as in the tree model). + feature_map_internal_original: A mapping of feature indices from the internal feature + indices (as in the tree model) to the original feature indices (as in the model). + original_output_type: The original output type of the tree model. The default value is + "raw". """ children_left: np.ndarray[int] @@ -43,12 +63,23 @@ class TreeModel: root_node_id: Optional[int] = None n_nodes: Optional[int] = None nodes: Optional[np.ndarray[int]] = None - feature_mapping_old_new: Optional[dict] = None - feature_mapping_new_old: Optional[dict] = None + feature_map_original_internal: Optional[dict[int, int]] = None + feature_map_internal_original: Optional[dict[int, int]] = None + original_output_type: str = "raw" def __getitem__(self, item) -> Any: return getattr(self, item) + def compute_empty_prediction(self) -> None: + """Compute the empty prediction of the tree model. + + The method computes the empty prediction of the tree model by taking the weighted average of + the leaf node values. The method modifies the tree model in place. 
+ """ + self.empty_prediction = compute_empty_prediction( + self.values[self.leaf_mask], self.node_sample_weight[self.leaf_mask] + ) + def __post_init__(self) -> None: # setup leaf mask if self.leaf_mask is None: @@ -59,9 +90,7 @@ def __post_init__(self) -> None: self.thresholds = np.where(self.leaf_mask, np.nan, self.thresholds) # setup empty prediction if self.empty_prediction is None: - self.empty_prediction = compute_empty_prediction( - self.values[self.leaf_mask], self.node_sample_weight[self.leaf_mask] - ) + self.compute_empty_prediction() unique_features = set(np.unique(self.features)) unique_features.discard(-2) # remove leaf node "features" # setup number of features @@ -83,11 +112,11 @@ def __post_init__(self) -> None: if self.nodes is None: self.nodes = np.arange(self.n_nodes) # setup original feature mapping - if self.feature_mapping_old_new is None: - self.feature_mapping_old_new = {i: i for i in unique_features} + if self.feature_map_original_internal is None: + self.feature_map_original_internal = {i: i for i in unique_features} # setup new feature mapping - if self.feature_mapping_new_old is None: - self.feature_mapping_new_old = {i: i for i in unique_features} + if self.feature_map_internal_original is None: + self.feature_map_internal_original = {i: i for i in unique_features} def reduce_feature_complexity(self) -> None: """Reduces the feature complexity of the tree model. @@ -119,8 +148,8 @@ def reduce_feature_complexity(self) -> None: new_features[i] = new_value self.features = new_features self.feature_ids = new_feature_ids - self.feature_mapping_old_new = mapping_old_new - self.feature_mapping_new_old = mapping_new_old + self.feature_map_original_internal = mapping_old_new + self.feature_map_internal_original = mapping_new_old self.n_features_in_tree = len(new_feature_ids) self.max_feature_id = self.n_features_in_tree - 1 @@ -131,6 +160,8 @@ class EdgeTree: The dataclass stores the information of an edge representation of the tree in a way that is easy to access and manipulate for the TreeSHAP-IQ algorithm. + + # TODO: add more information about the attributes """ parents: np.ndarray[int] @@ -153,3 +184,30 @@ def __post_init__(self) -> None: # setup has ancestors if self.has_ancestors is None: self.has_ancestors = self.ancestors > -1 + + +def convert_tree_output_type(tree_model: TreeModel, output_type: str) -> TreeModel: + """Convert the output type of the tree model. + + Args: + tree_model: The tree model to convert. + output_type: The output type to convert the tree model to. Can be "raw", "probability", or + "logit". + + Returns: + The converted tree model. 
+ """ + original_output_type = tree_model.original_output_type + if original_output_type == output_type or output_type == "raw": # no conversion needed + return tree_model + # transform probability to logit + if original_output_type == "probability" and output_type == "logit": + tree_model.values = np.log(tree_model.values / (1 - tree_model.values)) + tree_model.compute_empty_prediction() # recompute the empty prediction + tree_model.original_output_type = output_type + # transform logit to probability + if original_output_type == "logit" and output_type == "probability": + tree_model.values = 1 / (1 + np.exp(-tree_model.values)) + tree_model.compute_empty_prediction() # recompute the empty prediction + tree_model.original_output_type = output_type + return tree_model diff --git a/shapiq/explainer/tree/conversion/sklearn.py b/shapiq/explainer/tree/conversion/sklearn.py index 5d7ec2a3..6c488b6e 100644 --- a/shapiq/explainer/tree/conversion/sklearn.py +++ b/shapiq/explainer/tree/conversion/sklearn.py @@ -13,7 +13,6 @@ def convert_sklearn_forest( tree_model: Model, class_label: Optional[int] = None, - output_type: str = "raw", ) -> list[TreeModel]: """Transforms a scikit-learn random forest to the format used by shapiq. @@ -21,26 +20,19 @@ def convert_sklearn_forest( tree_model: The scikit-learn random forest model to convert. class_label: The class label of the model to explain. Only used for classification models. Defaults to 0. - output_type: Denotes if the tree output values should be transformed or not. Defaults - to None ('raw'). Possible values are 'raw', 'probability', and 'logits'. Returns: The converted random forest model. """ scaling = 1.0 / len(tree_model.estimators_) return [ - convert_sklearn_tree( - tree, scaling=scaling, class_label=class_label, output_type=output_type - ) + convert_sklearn_tree(tree, scaling=scaling, class_label=class_label) for tree in tree_model.estimators_ ] def convert_sklearn_tree( - tree_model: Model, - class_label: Optional[int] = None, - scaling: float = 1.0, - output_type: str = "raw", + tree_model: Model, class_label: Optional[int] = None, scaling: float = 1.0 ) -> TreeModel: """Convert a scikit-learn decision tree to the format used by shapiq. @@ -49,12 +41,11 @@ def convert_sklearn_tree( class_label: The class label of the model to explain. Only used for classification models. Defaults to 0. scaling: The scaling factor for the tree values. - output_type: Denotes if the tree output values should be transformed or not. Defaults - to None ('raw'). Possible values are 'raw', 'probability', and 'logits'. Returns: The converted decision tree model. 
""" + output_type = "raw" tree_values = tree_model.tree_.value.copy() * scaling # set class label if not given and model is a classifier if safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier") and class_label is None: @@ -66,9 +57,7 @@ def convert_sklearn_tree( tree_values = tree_values[:, 0, :] tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True) tree_values = tree_values[:, class_label] - if output_type != "raw": - # TODO: Add support for logits output type - raise NotImplementedError("Only raw output types are currently supported.") + output_type = "probability" tree_values = tree_values.flatten() return TreeModel( children_left=tree_model.tree_.children_left, @@ -78,4 +67,5 @@ def convert_sklearn_tree( values=tree_values, node_sample_weight=tree_model.tree_.weighted_n_node_samples, empty_prediction=None, # compute empty prediction later + original_output_type=output_type, ) diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index bfaf0ded..f8a0c837 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -8,7 +8,7 @@ from interaction_values import InteractionValues from .treeshapiq import TreeModel, TreeSHAPIQ -from .validation import _validate_model +from .validation import validate_tree_model class TreeExplainer(Explainer): @@ -21,7 +21,9 @@ def __init__( output_type: str = "raw", ) -> None: # validate and parse model - validated_model = _validate_model(model, class_label=class_label, output_type=output_type) + validated_model = validate_tree_model( + model, class_label=class_label, output_type=output_type + ) self._trees: Union[TreeModel, list[TreeModel]] = copy.deepcopy(validated_model) if not isinstance(self._trees, list): self._trees = [self._trees] diff --git a/shapiq/explainer/tree/treeshapiq.py b/shapiq/explainer/tree/treeshapiq.py index d254abb8..e8abf10d 100644 --- a/shapiq/explainer/tree/treeshapiq.py +++ b/shapiq/explainer/tree/treeshapiq.py @@ -12,7 +12,7 @@ from .base import EdgeTree, TreeModel from .conversion.edges import create_edge_tree -from .validation import _validate_model +from .validation import validate_tree_model class TreeSHAPIQ: @@ -45,7 +45,7 @@ def __init__( self._interaction_type: str = interaction_type # validate and parse model - validated_model = _validate_model(model) # the parsed and validated model + validated_model = validate_tree_model(model) # the parsed and validated model # TODO: add support for other sample weights self._tree: TreeModel = copy.deepcopy(validated_model) self._relevant_features: np.ndarray = np.array(list(self._tree.feature_ids), dtype=int) diff --git a/shapiq/explainer/tree/validation.py b/shapiq/explainer/tree/validation.py index 68f969a5..be9e6af5 100644 --- a/shapiq/explainer/tree/validation.py +++ b/shapiq/explainer/tree/validation.py @@ -3,7 +3,7 @@ from shapiq.utils import safe_isinstance -from .base import TreeModel +from .base import TreeModel, convert_tree_output_type from .conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree SUPPORTED_MODELS = { @@ -14,7 +14,7 @@ } -def _validate_model( +def validate_tree_model( model: Any, class_label: Optional[int] = None, output_type: str = "raw" ) -> Union[TreeModel, list[TreeModel]]: """Validate the model. @@ -27,21 +27,50 @@ def _validate_model( Returns: The validated model and the model function. """ + if output_type not in ["raw", "probability", "logit"]: + raise ValueError( + "Invalid output type. Supported output types are: 'raw', 'probability', 'logit'." 
+ ) + + # direct returns for base tree models and dict as model # tree model (is already in the correct format) if type(model).__name__ == "TreeModel": return model # dict as model is parsed to TreeModel (the dict needs to have the correct format and names) if type(model).__name__ == "dict": return TreeModel(**model) + + # transformation of common machine learning libraries to TreeModel # sklearn decision trees if safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") or safe_isinstance( model, "sklearn.tree.DecisionTreeClassifier" ): - return convert_sklearn_tree(model, class_label=class_label, output_type=output_type) + tree_model = convert_sklearn_tree(model, class_label=class_label) # sklearn random forests - if safe_isinstance(model, "sklearn.ensemble.RandomForestRegressor") or safe_isinstance( + elif safe_isinstance(model, "sklearn.ensemble.RandomForestRegressor") or safe_isinstance( model, "sklearn.ensemble.RandomForestClassifier" ): - return convert_sklearn_forest(model, class_label=class_label, output_type=output_type) + tree_model = convert_sklearn_forest(model, class_label=class_label) # unsupported model - raise TypeError("Unsupported model type." f"Supported models are: {SUPPORTED_MODELS}") + else: + raise TypeError("Unsupported model type." f"Supported models are: {SUPPORTED_MODELS}") + + # if single tree model put it in a list + if not isinstance(tree_model, list): + tree_model = [tree_model] + + # adapt output type if necessary + if output_type != "raw": + # check if the output type of the tree model is the same as the requested output type + trees_to_adapt = [] + for i, tree in enumerate(tree_model): + if tree.original_output_type != output_type: + trees_to_adapt.append(i) + if trees_to_adapt: + for i in trees_to_adapt: + tree_model[i] = convert_tree_output_type(tree_model[i], output_type) + + if len(tree_model) == 1: + tree_model = tree_model[0] + + return tree_model diff --git a/tests/conftest.py b/tests/conftest.py index cb5fd379..6e16e374 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,8 @@ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier +from shapiq.explainer.tree import TreeModel + @pytest.fixture def dt_reg_model() -> DecisionTreeRegressor: @@ -35,6 +37,26 @@ def dt_clf_model() -> DecisionTreeClassifier: return model +@pytest.fixture +def dt_clf_model_tree_model() -> TreeModel: + """Return a simple decision tree as a TreeModel.""" + from shapiq.explainer.tree.validation import validate_tree_model + + X, y = make_classification( + n_samples=100, + n_features=7, + random_state=42, + n_classes=3, + n_informative=7, + n_repeated=0, + n_redundant=0, + ) + model = DecisionTreeClassifier(random_state=42, max_depth=3) + model.fit(X, y) + tree_model = validate_tree_model(model) + return tree_model + + @pytest.fixture def rf_reg_model() -> RandomForestRegressor: """Return a simple random forest model.""" diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py index fbcd857d..8abebcfc 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from explainer.tree import TreeModel from shapiq.explainer.tree import TreeExplainer @@ -20,7 +21,7 @@ def test_decision_tree_classifier(dt_clf_model, background_clf_data): assert True # 
check with invalid output type - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): _ = TreeExplainer( model=dt_clf_model, max_order=2, min_order=1, output_type="invalid_output_type" ) @@ -54,3 +55,39 @@ def test_random_forrest_classification(rf_clf_model, background_clf_data): explanation = explainer.explain(x_explain) assert type(explanation).__name__ == "InteractionValues" # check correct return type + + +def test_against_shap_implementation(): + """Test the tree explainer against the shap implementation's tree explainer results.""" + # manual values for a tree to test against the shap implementation + children_left = np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]) + children_right = np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]) + features = np.asarray([0, 1, 0, -2, -2, -2, 2, -2, -2]) + thresholds = np.asarray([0, 0, -0.5, -2, -2, -2, 0, -2, -2]) + node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]) + + # create a classification tree model + values = [110, 105, 95, 20, 50, 100, 75, 10, 40] + values = [values[i] / max(values) for i in range(len(values))] + values = np.asarray(values) + print(values) + + x_explain = np.asarray([-1, -0.5, 1, 0]) + + tree_model = TreeModel( + children_left=children_left, + children_right=children_right, + features=features, + thresholds=thresholds, + node_sample_weight=node_sample_weight, + values=values, + original_output_type="probability", + ) + + explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1) + explanation = explainer.explain(x_explain) + + assert explanation[(0,)] == pytest.approx(-0.09263158, abs=1e-4) + assert explanation[(1,)] == pytest.approx(-0.12100478, abs=1e-4) + assert explanation[(2,)] == pytest.approx(0.02727273, abs=1e-4) + assert explanation[(3,)] == pytest.approx(0.0, abs=1e-4) diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py index 466e676b..2ad25f8f 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py @@ -1,20 +1,64 @@ """This test module contains all tests for the validation functions of the tree explainer implementation.""" +import copy + import pytest +import numpy as np from shapiq import safe_isinstance -from shapiq.explainer.tree.validation import _validate_model +from shapiq.explainer.tree.validation import validate_tree_model -def test_validate_model(dt_clf_model, dt_reg_model): +def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model): """Test the validation of the model.""" class_path_str = ["explainer.tree.base.TreeModel"] # sklearn dt models are supported - tree_model = _validate_model(dt_clf_model) + tree_model = validate_tree_model(dt_clf_model) assert safe_isinstance(tree_model, class_path_str) - tree_model = _validate_model(dt_reg_model) + tree_model = validate_tree_model(dt_reg_model) assert safe_isinstance(tree_model, class_path_str) + # sklearn rf models are supported + tree_model = validate_tree_model(rf_clf_model) + for tree in tree_model: + assert safe_isinstance(tree, class_path_str) + tree_model = validate_tree_model(rf_reg_model) + for tree in tree_model: + assert safe_isinstance(tree, class_path_str) - # finally, test the unsupported model + # test the unsupported model with pytest.raises(TypeError): - _validate_model("unsupported_model") + validate_tree_model("unsupported_model") + + +def 
test_validate_output_types(dt_clf_model, dt_clf_model_tree_model): + class_path_str = ["explainer.tree.base.TreeModel"] + + # test with invalid output type + with pytest.raises(ValueError): + validate_tree_model(dt_clf_model, output_type="invalid_output_type") + + # test with 'raw' output type + tree_model = validate_tree_model(dt_clf_model, output_type="raw") + assert safe_isinstance(tree_model, class_path_str) + + # test with 'probability' output type (probability from probability) + tree_model = validate_tree_model(dt_clf_model, output_type="probability") + assert safe_isinstance(tree_model, class_path_str) + + # test with 'logit' output type (logit from probability) + tree_model = validate_tree_model(dt_clf_model, output_type="logit") + assert safe_isinstance(tree_model, class_path_str) + + from shapiq.explainer.tree.base import convert_tree_output_type + + # test with 'probability from 'logit' output type + tree_model_logit = copy.deepcopy(dt_clf_model_tree_model) + tree_model_logit.original_output_type = "logit" + # manually change the values to logit from probabilities + tree_model_logit.values = np.log(tree_model_logit.values / (1 - tree_model_logit.values)) + tree_model_logit = convert_tree_output_type(tree_model_logit, output_type="probability") + assert safe_isinstance(tree_model_logit, class_path_str) + + # test edge cases + tree_model = convert_tree_output_type(dt_clf_model_tree_model, output_type="raw") + assert safe_isinstance(tree_model, class_path_str) diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py index b1717ced..f82c94b6 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py @@ -81,7 +81,8 @@ def test_init(dt_clf_model, background_clf_data): ), ], ) -def test_manual_tree(index: str, expected: dict): +def test_against_old_treeshapiq_implementation(index: str, expected: dict): + """Test the tree explainer against the old TreeSHAP-IQ implementation's results.""" # manual values for a tree to test against the original treeshapiq implementation children_left = np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]) children_right = np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]) From bd4a97df78d02f1669609697930177272fe27bc7 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 22 Mar 2024 13:46:20 +0100 Subject: [PATCH 2/6] updates documentation to TreeSHAP-IQ --- shapiq/explainer/tree/treeshapiq.py | 74 +++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/shapiq/explainer/tree/treeshapiq.py b/shapiq/explainer/tree/treeshapiq.py index e8abf10d..bdfa8b5c 100644 --- a/shapiq/explainer/tree/treeshapiq.py +++ b/shapiq/explainer/tree/treeshapiq.py @@ -17,16 +17,33 @@ class TreeSHAPIQ: """ - The explainer for tree-based models using the TreeSHAP-IQ algorithm. + The explainer for tree-based models using the TreeSHAP-IQ algorithm. For a detailed presentation + of the algorithm, see the original paper: https://arxiv.org/abs/2401.12069. + + TreeSHAP-IQ is an algorithm for computing Shapley Interaction values for tree-based models. + It is heavily based on the Linear TreeSHAP algorithm (outlined in https://proceedings.neurips.cc/paper_files/paper/2022/hash/a5a3b1ef79520b7cd122d888673a3ebc-Abstract-Conference.html) + but extended to compute Shapley Interaction values up to a given order. 
TreeSHAP-IQ needs to + visit each node only once and makes use of polynomial arithmetic to compute the Shapley + Interaction values efficiently. Args: - model: The tree-based model to explain as an explainer.tree.conversion.TreeModel. + model: A single tree-based model to explain. Note unlike the TreeExplainer class, + TreeSHAP-IQ only supports a single tree model. The tree model can be a dictionary + representation of the tree, a `TreeModel` object, or any other tree model supported by + the `shapiq.explainer.tree.validation.validate_tree_model` function. max_order: The maximum interaction order to be computed. An interaction order of 1 corresponds to the Shapley value. Any value higher than 1 computes the Shapley - interactions values up to that order. Defaults to 2. + interaction values up to that order. Defaults to 2. min_order: The minimum interaction order to be computed. Defaults to 1. interaction_type: The type of interaction to be computed. The interaction type can be - "k-SII" (default) or "SII", "STI", "FSI", "BZF". + "k-SII" (default), "SII", "STI", "FSI", or "BZF". All indices apart from "BZF" will + reduce to the "SV" (Shapley value) for order 1. + verbose: Whether to print information about the tree during initialization. Defaults to + False. + + Note: + This class is not intended to be used directly. Instead, use the `TreeExplainer` class to + explain tree-based models which internally uses then the TreeSHAP-IQ algorithm. """ def __init__( @@ -151,6 +168,22 @@ def _compute_shapley_interaction_values( quotient_poly_down: np.ndarray[float] = None, depth: int = 0, ) -> None: + """Computes the Shapley Interaction values for a given instance x_explain and interaction + order. This function is called recursively for each node in the tree. + + Args: + x_explain: The instance to be explained. + order: The interaction order for which the Shapley Interaction values should be + computed. Defaults to 1. + node_id: The node ID of the current node in the tree. Defaults to 0. + summary_poly_down: The summary polynomial for the current node. Defaults to None (init). + summary_poly_up: The summary polynomial propagated up the tree. Defaults to None (init). + interaction_poly_down: The interaction polynomial for the current node. Defaults to + None (init). + quotient_poly_down: The quotient polynomial for the current node. Defaults to None + (init). + depth: The depth of the current node in the tree. Defaults to 0. + """ # reset activations for new calculations if node_id == 0: self._activations.fill(False) @@ -282,7 +315,7 @@ def _compute_shapley_interaction_values( self.Ns_id[self.n_interpolation_size, : self.n_interpolation_size], ) interaction_update *= self._psi( - summary_poly_up[depth, :], # TODO fix: wrong at this point (should not be zero) + summary_poly_up[depth, :], D_power, quotient_poly_down[depth, interactions_seen], self.Ns, @@ -326,7 +359,7 @@ def _compute_shapley_interaction_values( self.shapley_interactions[interactions_with_ancestor_to_update] -= update @staticmethod - def _psi(E, D_power, quotient_poly, Ns, degree): + def _psi(E, D_power, quotient_poly, Ns, degree) -> np.ndarray[float]: # TODO: add docstring d = degree + 1 n = Ns[d, :d] @@ -355,7 +388,21 @@ def _get_polynomials( summary_poly_up: np.ndarray[float] = None, interaction_poly_down: np.ndarray[float] = None, quotient_poly_down: np.ndarray[float] = None, - ): + ) -> tuple[np.ndarray[float], np.ndarray[float], np.ndarray[float], np.ndarray[float]]: + """Retrieves the polynomials for a given interaction order. 
It initializes the polynomials + for the first call of the recursive explanation function. + + Args: + order: The interaction order for which the polynomials should be loaded. + summary_poly_down: The summary polynomial for the current node. Defaults to None. + summary_poly_up: The summary polynomial propagated up the tree. Defaults to None. + interaction_poly_down: The interaction polynomial for the current node. Defaults to None. + quotient_poly_down: The quotient polynomial for the current node. Defaults to None. + + Returns: + tuple: The summary polynomial down, the summary polynomial up, the interaction polynomial + down, and the quotient polynomial down. + """ if summary_poly_down is None: summary_poly_down = np.zeros((self._edge_tree.max_depth + 1, self.n_interpolation_size)) summary_poly_down[0, :] = 1 @@ -381,7 +428,7 @@ def _get_polynomials( quotient_poly_down[0, :] = 1 return summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down - def _prepare_variables_for_order(self, interaction_order: int): + def _prepare_variables_for_order(self, interaction_order: int) -> None: """Retrieves the precomputed variables for a given interaction order. This function is called before the recursive explanation function is called. @@ -458,7 +505,16 @@ def _precalculate_interaction_ancestors( self, interaction_order, n_features ) -> dict[int, np.ndarray]: """Calculates the position of the ancestors of the interactions for the tree for a given - order of interactions.""" + order of interactions. + + Args: + interaction_order: The interaction order for which the ancestors should be computed. + n_features: The number of features in the model. + + Returns: + subset_ancestors: A dictionary containing the ancestors of the interactions for each + node in the tree. + """ # stores position of interactions counter_interaction = 0 From 41bf459975e4ac66ed7ae27261a7f6995f1883da Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 22 Mar 2024 13:55:43 +0100 Subject: [PATCH 3/6] add interaction_type as parameter to TreeExplainer --- shapiq/explainer/tree/explainer.py | 3 ++- shapiq/explainer/tree/treeshapiq.py | 5 ++++ .../test_tree_treeshapiq.py | 25 +++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index f8a0c837..28efb677 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -17,6 +17,7 @@ def __init__( model: Union[dict, TreeModel, Any], max_order: int = 2, min_order: int = 1, + interaction_type: str = "SII", # TODO: add tests and fix bug with k-SII not working class_label: Optional[int] = None, output_type: str = "raw", ) -> None: @@ -36,7 +37,7 @@ def __init__( # setup explainers for all trees self._treeshapiq_explainers: list[TreeSHAPIQ] = [ - TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type="SII") + TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type=interaction_type) for _tree in self._trees ] diff --git a/shapiq/explainer/tree/treeshapiq.py b/shapiq/explainer/tree/treeshapiq.py index bdfa8b5c..49fdeb62 100644 --- a/shapiq/explainer/tree/treeshapiq.py +++ b/shapiq/explainer/tree/treeshapiq.py @@ -57,6 +57,11 @@ def __init__( # set parameters self._root_node_id = 0 self.verbose = verbose + if max_order < min_order or max_order < 1 or min_order < 1: + raise ValueError( + "The maximum order must be greater than the minimum order and both must be greater " + "than 0." 
+ ) self._max_order: int = max_order self._min_order: int = min_order self._interaction_type: str = interaction_type diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py index f82c94b6..436f5ad5 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py @@ -109,3 +109,28 @@ def test_against_old_treeshapiq_implementation(index: str, expected: dict): for key, value in expected.items(): assert np.isclose(explanation[key], value, atol=1e-5) + + +def test_edge_case_params(): + """Test the TreeSHAPIQ class with edge case parameters.""" + children_left = np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]) + children_right = np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]) + features = np.asarray([0, 1, 0, -2, -2, -2, 2, -2, -2]) + thresholds = np.asarray([0, 0, -0.5, -2, -2, -2, 0, -2, -2]) + node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]) + values = np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40]) + + #x_explain = np.asarray([-1, -0.5, 1, 0]) + + tree_model = TreeModel( + children_left=children_left, + children_right=children_right, + features=features, + thresholds=thresholds, + node_sample_weight=node_sample_weight, + values=values, + ) + + # test with max_order = 0 + with pytest.raises(ValueError): + _ = TreeSHAPIQ(model=tree_model, max_order=0) From 66bde9e728fcf07ef8f592538f384b7186f07687 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 22 Mar 2024 13:55:47 +0100 Subject: [PATCH 4/6] add interaction_type as parameter to TreeExplainer --- .../tests_tree_explainer/test_tree_treeshapiq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py index 436f5ad5..6f844ae0 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py @@ -120,7 +120,7 @@ def test_edge_case_params(): node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]) values = np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40]) - #x_explain = np.asarray([-1, -0.5, 1, 0]) + # x_explain = np.asarray([-1, -0.5, 1, 0]) tree_model = TreeModel( children_left=children_left, From a51a0a7357fb2d9578c1c19b9a4319776bc2b59b Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 28 Mar 2024 10:48:20 +0100 Subject: [PATCH 5/6] adds sum, len, and getitem with index to InteractionValues --- shapiq/interaction_values.py | 15 ++++++++++++-- tests/test_base_interaction_values.py | 29 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index c6b2dfab..3fe257e8 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -70,15 +70,26 @@ def __str__(self) -> str: """Returns the string representation of the InteractionValues object.""" return self.__repr__() - def __getitem__(self, item: tuple[int, ...]) -> float: + def __len__(self) -> int: + """Returns the length of the InteractionValues object.""" + return len(self.values) # might better to return the theoretical no. 
of interactions + + def __iter__(self) -> np.nditer: + """Returns an iterator over the values of the InteractionValues object.""" + return np.nditer(self.values) + + def __getitem__(self, item: Union[int, tuple[int, ...]]) -> float: """Returns the score for the given interaction. Args: - item: The interaction for which to return the score. + item: The interaction as a tuple of integers for which to return the score. If `item` is + an integer it serves as the index to the values vector. Returns: The interaction value. If the interaction is not present zero is returned. """ + if isinstance(item, int): + return float(self.values[item]) item = tuple(sorted(item)) try: return float(self.values[self.interaction_lookup[item]]) diff --git a/tests/test_base_interaction_values.py b/tests/test_base_interaction_values.py index f0f8696f..726b2f19 100644 --- a/tests/test_base_interaction_values.py +++ b/tests/test_base_interaction_values.py @@ -89,6 +89,13 @@ def test_initialization(index, n, min_order, max_order, estimation_budget, estim # check getitem with invalid interaction (not in interaction_lookup) assert interaction_values[(100, 101)] == 0 # invalid interaction is 0 + # test getitem with integer as input + assert interaction_values[0] == interaction_values.values[0] + assert interaction_values[-1] == interaction_values.values[-1] + + # test __len__ + assert len(interaction_values) == len(interaction_values.values) + def test_add(): """Tests the __add__ method of the InteractionValues dataclass.""" @@ -226,3 +233,25 @@ def test_mul(): assert np.all(interaction_values_mul.values == 2 * interaction_values_first.values) interaction_values_mul = interaction_values_first * 2.0 assert np.all(interaction_values_mul.values == 2.0 * interaction_values_first.values) + + +def test_sum(): + """Tests the sum method of the InteractionValues dataclass.""" + index = "SII" + n = 5 + min_order = 1 + max_order = 2 + interaction_lookup = { + interaction: i for i, interaction in enumerate(powerset(range(n), min_order, max_order)) + } + values = np.random.rand(len(interaction_lookup)) + interaction_values = InteractionValues( + values=values, + index=index, + n_players=n, + min_order=min_order, + max_order=max_order, + interaction_lookup=interaction_lookup, + ) + + assert np.isclose(sum(interaction_values), np.sum(interaction_values.values)) From 3e47828c3c2025675dc386f343193d1ced4e8f48 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 28 Mar 2024 10:48:41 +0100 Subject: [PATCH 6/6] adds tests to logit/probit transformation --- shapiq/explainer/tree/base.py | 18 ++++-- shapiq/explainer/tree/validation.py | 23 +++++-- .../test_tree_explainer.py | 61 ++++++++++++++++++- .../test_tree_explainer_validate.py | 12 +++- .../test_tree_treeshapiq.py | 2 - 5 files changed, 100 insertions(+), 16 deletions(-) diff --git a/shapiq/explainer/tree/base.py b/shapiq/explainer/tree/base.py index 2a2e8d7d..7fac1ca8 100644 --- a/shapiq/explainer/tree/base.py +++ b/shapiq/explainer/tree/base.py @@ -186,7 +186,7 @@ def __post_init__(self) -> None: self.has_ancestors = self.ancestors > -1 -def convert_tree_output_type(tree_model: TreeModel, output_type: str) -> TreeModel: +def convert_tree_output_type(tree_model: TreeModel, output_type: str) -> tuple[TreeModel, bool]: """Convert the output type of the tree model. Args: @@ -195,14 +195,24 @@ def convert_tree_output_type(tree_model: TreeModel, output_type: str) -> TreeMod "logit". Returns: - The converted tree model. 
+ The converted tree model and a warning flag indicating whether invalid probability values + were adjusted in logit transformation. """ + warning_flag = False original_output_type = tree_model.original_output_type if original_output_type == output_type or output_type == "raw": # no conversion needed - return tree_model + return tree_model, warning_flag # transform probability to logit if original_output_type == "probability" and output_type == "logit": tree_model.values = np.log(tree_model.values / (1 - tree_model.values)) + # give a warning if leaf values are replaced + if np.any(tree_model.values[tree_model.leaf_mask] == np.inf) or np.any( + tree_model.values[tree_model.leaf_mask] == -np.inf + ): + warning_flag = True + # replace +inf with 14 and -inf with -14 + tree_model.values = np.where(tree_model.values == np.inf, 14, tree_model.values) + tree_model.values = np.where(tree_model.values == -np.inf, -14, tree_model.values) tree_model.compute_empty_prediction() # recompute the empty prediction tree_model.original_output_type = output_type # transform logit to probability @@ -210,4 +220,4 @@ def convert_tree_output_type(tree_model: TreeModel, output_type: str) -> TreeMod tree_model.values = 1 / (1 + np.exp(-tree_model.values)) tree_model.compute_empty_prediction() # recompute the empty prediction tree_model.original_output_type = output_type - return tree_model + return tree_model, warning_flag diff --git a/shapiq/explainer/tree/validation.py b/shapiq/explainer/tree/validation.py index be9e6af5..e12da4f3 100644 --- a/shapiq/explainer/tree/validation.py +++ b/shapiq/explainer/tree/validation.py @@ -1,4 +1,5 @@ """This module contains conversion functions for the tree explainer implementation.""" +import warnings from typing import Any, Optional, Union from shapiq.utils import safe_isinstance @@ -35,14 +36,13 @@ def validate_tree_model( # direct returns for base tree models and dict as model # tree model (is already in the correct format) if type(model).__name__ == "TreeModel": - return model + tree_model = model # dict as model is parsed to TreeModel (the dict needs to have the correct format and names) - if type(model).__name__ == "dict": - return TreeModel(**model) - + elif type(model).__name__ == "dict": + tree_model = TreeModel(**model) # transformation of common machine learning libraries to TreeModel # sklearn decision trees - if safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") or safe_isinstance( + elif safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") or safe_isinstance( model, "sklearn.tree.DecisionTreeClassifier" ): tree_model = convert_sklearn_tree(model, class_label=class_label) @@ -67,8 +67,19 @@ def validate_tree_model( if tree.original_output_type != output_type: trees_to_adapt.append(i) if trees_to_adapt: + warn_flag = False for i in trees_to_adapt: - tree_model[i] = convert_tree_output_type(tree_model[i], output_type) + tree_model[i], warn_flag = convert_tree_output_type(tree_model[i], output_type) + warn_flag += warn_flag + # at least one tree model was adapted (invalid probability values were adjusted in + # logit transformation) + if warn_flag: + warnings.warn( + UserWarning( + "Invalid probability values (i.e. p=0 or p=1) were numerically adjusted " + "in logit transformation (+/- inf is set to +/- 14)." 
+ ) + ) if len(tree_model) == 1: tree_model = tree_model[0] diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py index 8abebcfc..49cd26b8 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py @@ -26,6 +26,10 @@ def test_decision_tree_classifier(dt_clf_model, background_clf_data): model=dt_clf_model, max_order=2, min_order=1, output_type="invalid_output_type" ) + explainer = _ = TreeExplainer(model=dt_clf_model, max_order=1, min_order=1, class_label=1) + explanation = explainer.explain(x_explain) + print(explanation) + def test_decision_tree_regression(dt_reg_model, background_reg_data): """Test TreeExplainer with a simple decision tree regressor.""" @@ -84,10 +88,65 @@ def test_against_shap_implementation(): original_output_type="probability", ) - explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1) + explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1, interaction_type="SII") explanation = explainer.explain(x_explain) assert explanation[(0,)] == pytest.approx(-0.09263158, abs=1e-4) assert explanation[(1,)] == pytest.approx(-0.12100478, abs=1e-4) assert explanation[(2,)] == pytest.approx(0.02727273, abs=1e-4) assert explanation[(3,)] == pytest.approx(0.0, abs=1e-4) + + explainer = TreeExplainer(model=tree_model, max_order=1, min_order=1, interaction_type="SII") + explanation = explainer.explain(x_explain) + print(explanation) + print(explainer._treeshapiq_explainers[0]._tree.empty_prediction) + + explainer = TreeExplainer( + model=tree_model, max_order=1, min_order=1, interaction_type="SII", output_type="logit" + ) + explanation = explainer.explain(x_explain) + print(explanation) + print(explainer._treeshapiq_explainers[0]._tree.empty_prediction) + + +def test_logit_probit_conversion(dt_clf_model, background_clf_data): + """This test checks the conversion of the output types for a tree classifier.""" + x_explain = background_clf_data[0] + + # test with 'raw' output type (no change) + explainer_raw = TreeExplainer(model=dt_clf_model, max_order=1, min_order=1, output_type="raw") + explainer_raw_explanation = explainer_raw.explain(x_explain) + explainer_raw_empty_pred = explainer_raw._treeshapiq_explainers[0]._tree.empty_prediction + + # test with 'probability' output type (probability from probability, no change to raw) + explainer_prob = TreeExplainer( + model=dt_clf_model, max_order=1, min_order=1, output_type="probability" + ) + explainer_prob_explanation = explainer_prob.explain(x_explain) + explainer_prob_empty_pred = explainer_prob._treeshapiq_explainers[0]._tree.empty_prediction + + # test with 'logit' output type (logit from probability) + with pytest.warns(UserWarning): + explainer_logit = TreeExplainer( + model=dt_clf_model, max_order=1, min_order=1, output_type="logit" + ) + explainer_logit_explanation = explainer_logit.explain(x_explain) + explainer_logit_empty_pred = explainer_logit._treeshapiq_explainers[0]._tree.empty_prediction + + # make assertions + assert explainer_raw_explanation == explainer_prob_explanation + assert explainer_raw_explanation != explainer_logit_explanation + assert explainer_prob_explanation != explainer_logit_explanation + + # manually transform the probabilities to logits + sum_raw = sum(explainer_raw_explanation) + explainer_raw_empty_pred + sum_prob = sum(explainer_prob_explanation) + explainer_prob_empty_pred + sum_logit = 
sum(explainer_logit_explanation) + explainer_logit_empty_pred
+
+    manual_logit = np.log(sum_prob / (1 - sum_prob))
+    manual_prob = 1 / (1 + np.exp(-sum_logit))
+
+    assert sum_prob == sum_raw
+    assert manual_prob == pytest.approx(sum_prob, abs=1e-4)
+    # logit values explode more and are more difficult to compare
+    assert manual_logit < 3 and sum_logit < 3
diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py
index 2ad25f8f..2856ffdc 100644
--- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py
+++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py
@@ -30,7 +30,13 @@ def test_validate_model(dt_clf_model, dt_reg_model, rf_reg_model, rf_clf_model):
         validate_tree_model("unsupported_model")


-def test_validate_output_types(dt_clf_model, dt_clf_model_tree_model):
+def test_validate_output_types_parameters(dt_clf_model, dt_clf_model_tree_model):
+    """This test checks whether the correct output types are validated.
+
+    This test does not check if the conversion of the output types is semantically correct. This is
+    tested in the next test.
+
+    """
     class_path_str = ["explainer.tree.base.TreeModel"]

     # test with invalid output type
@@ -56,9 +62,9 @@
     tree_model_logit.original_output_type = "logit"
     # manually change the values to logit from probabilities
     tree_model_logit.values = np.log(tree_model_logit.values / (1 - tree_model_logit.values))
-    tree_model_logit = convert_tree_output_type(tree_model_logit, output_type="probability")
+    tree_model_logit, _ = convert_tree_output_type(tree_model_logit, output_type="probability")
     assert safe_isinstance(tree_model_logit, class_path_str)

     # test edge cases
-    tree_model = convert_tree_output_type(dt_clf_model_tree_model, output_type="raw")
+    tree_model, _ = convert_tree_output_type(dt_clf_model_tree_model, output_type="raw")
     assert safe_isinstance(tree_model, class_path_str)
diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py
index 6f844ae0..49e971df 100644
--- a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py
+++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py
@@ -120,7 +120,6 @@ def test_edge_case_params():
     node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30])
     values = np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40])

-    # x_explain = np.asarray([-1, -0.5, 1, 0])
-
     tree_model = TreeModel(
         children_left=children_left,
         children_right=children_right,
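
A minimal usage sketch of the output types introduced in this series, assuming only the TreeExplainer API and InteractionValues behavior shown in the diffs above; the dataset, model, and variable names are illustrative, and the private attribute access at the end mirrors the tests rather than a public API.

# Sketch of the new `output_type` parameter (illustrative, not part of the patch).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

from shapiq.explainer.tree import TreeExplainer

X, y = make_classification(n_samples=100, n_features=7, random_state=42)
model = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
x_explain = X[0]

# sklearn classifiers store probabilities in their leaves, so "probability" behaves like "raw"
explainer_prob = TreeExplainer(model=model, max_order=1, min_order=1, output_type="probability")
sv_prob = explainer_prob.explain(x_explain)

# "logit" transforms the leaf values with log(p / (1 - p)) before explaining; pure leaves
# (p = 0 or p = 1) are clipped to +/- 14 and a UserWarning is emitted, as in validation.py
explainer_logit = TreeExplainer(model=model, max_order=1, min_order=1, output_type="logit")
sv_logit = explainer_logit.explain(x_explain)

# as in test_logit_probit_conversion, the attributions plus the empty prediction recover the
# model output in the chosen output space (private attribute access mirrors the tests)
prob_pred = sum(sv_prob) + explainer_prob._treeshapiq_explainers[0]._tree.empty_prediction
logit_pred = sum(sv_logit) + explainer_logit._treeshapiq_explainers[0]._tree.empty_prediction
print(prob_pred, 1.0 / (1.0 + np.exp(-logit_pred)))  # the two should roughly agree

The only difference between the two explainers is the space in which the leaf values, and therefore the attributions, are expressed; the underlying TreeSHAP-IQ computation is unchanged.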