Skip to content

Commit

Permalink
Merge pull request #84 from mmschlk/tree-explainer-ouput-spaces
Browse files Browse the repository at this point in the history
Adds logit/probit transformation to tree explainer and closes #56
  • Loading branch information
mmschlk authored Mar 28, 2024
2 parents c3369aa + 3e47828 commit 6dfc249
Show file tree
Hide file tree
Showing 11 changed files with 452 additions and 59 deletions.
90 changes: 79 additions & 11 deletions shapiq/explainer/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ class TreeModel:
the empty prediction is computed from the leaf values and the sample weights.
leaf_mask: The boolean mask of the leaf nodes in a tree. The default value is None. Then the
leaf mask is computed from the children left and right arrays.
n_features_in_tree: The number of features in the tree model. The default value is None.
Then the number of features in the tree model is computed from the unique feature
indices in the features array.
max_feature_id: The maximum feature index in the tree model. The default value is None. Then
the maximum feature index in the tree model is computed from the features array.
feature_ids: The feature indices of the decision nodes in the tree model. The default value
is None. Then the feature indices of the decision nodes in the tree model are computed
from the unique feature indices in the features array.
root_node_id: The root node id of the tree model. The default value is None. Then the root
node id of the tree model is set to 0.
n_nodes: The number of nodes in the tree model. The default value is None. Then the number
of nodes in the tree model is computed from the children left array.
nodes: The node ids of the tree model. The default value is None. Then the node ids of the
tree model are computed from the number of nodes in the tree model.
feature_map_original_internal: A mapping of feature indices from the original feature
indices (as in the model) to the internal feature indices (as in the tree model).
feature_map_internal_original: A mapping of feature indices from the internal feature
indices (as in the tree model) to the original feature indices (as in the model).
original_output_type: The original output type of the tree model. The default value is
"raw".
"""

children_left: np.ndarray[int]
Expand All @@ -43,12 +63,23 @@ class TreeModel:
root_node_id: Optional[int] = None
n_nodes: Optional[int] = None
nodes: Optional[np.ndarray[int]] = None
feature_mapping_old_new: Optional[dict] = None
feature_mapping_new_old: Optional[dict] = None
feature_map_original_internal: Optional[dict[int, int]] = None
feature_map_internal_original: Optional[dict[int, int]] = None
original_output_type: str = "raw"

def __getitem__(self, item) -> Any:
return getattr(self, item)

def compute_empty_prediction(self) -> None:
"""Compute the empty prediction of the tree model.
The method computes the empty prediction of the tree model by taking the weighted average of
the leaf node values. The method modifies the tree model in place.
"""
self.empty_prediction = compute_empty_prediction(
self.values[self.leaf_mask], self.node_sample_weight[self.leaf_mask]
)

def __post_init__(self) -> None:
# setup leaf mask
if self.leaf_mask is None:
Expand All @@ -59,9 +90,7 @@ def __post_init__(self) -> None:
self.thresholds = np.where(self.leaf_mask, np.nan, self.thresholds)
# setup empty prediction
if self.empty_prediction is None:
self.empty_prediction = compute_empty_prediction(
self.values[self.leaf_mask], self.node_sample_weight[self.leaf_mask]
)
self.compute_empty_prediction()
unique_features = set(np.unique(self.features))
unique_features.discard(-2) # remove leaf node "features"
# setup number of features
Expand All @@ -83,11 +112,11 @@ def __post_init__(self) -> None:
if self.nodes is None:
self.nodes = np.arange(self.n_nodes)
# setup original feature mapping
if self.feature_mapping_old_new is None:
self.feature_mapping_old_new = {i: i for i in unique_features}
if self.feature_map_original_internal is None:
self.feature_map_original_internal = {i: i for i in unique_features}
# setup new feature mapping
if self.feature_mapping_new_old is None:
self.feature_mapping_new_old = {i: i for i in unique_features}
if self.feature_map_internal_original is None:
self.feature_map_internal_original = {i: i for i in unique_features}

def reduce_feature_complexity(self) -> None:
"""Reduces the feature complexity of the tree model.
Expand Down Expand Up @@ -119,8 +148,8 @@ def reduce_feature_complexity(self) -> None:
new_features[i] = new_value
self.features = new_features
self.feature_ids = new_feature_ids
self.feature_mapping_old_new = mapping_old_new
self.feature_mapping_new_old = mapping_new_old
self.feature_map_original_internal = mapping_old_new
self.feature_map_internal_original = mapping_new_old
self.n_features_in_tree = len(new_feature_ids)
self.max_feature_id = self.n_features_in_tree - 1

Expand All @@ -131,6 +160,8 @@ class EdgeTree:
The dataclass stores the information of an edge representation of the tree in a way that is easy
to access and manipulate for the TreeSHAP-IQ algorithm.
# TODO: add more information about the attributes
"""

parents: np.ndarray[int]
Expand All @@ -153,3 +184,40 @@ def __post_init__(self) -> None:
# setup has ancestors
if self.has_ancestors is None:
self.has_ancestors = self.ancestors > -1


def convert_tree_output_type(tree_model: TreeModel, output_type: str) -> tuple[TreeModel, bool]:
"""Convert the output type of the tree model.
Args:
tree_model: The tree model to convert.
output_type: The output type to convert the tree model to. Can be "raw", "probability", or
"logit".
Returns:
The converted tree model and a warning flag indicating whether invalid probability values
were adjusted in logit transformation.
"""
warning_flag = False
original_output_type = tree_model.original_output_type
if original_output_type == output_type or output_type == "raw": # no conversion needed
return tree_model, warning_flag
# transform probability to logit
if original_output_type == "probability" and output_type == "logit":
tree_model.values = np.log(tree_model.values / (1 - tree_model.values))
# give a warning if leaf values are replaced
if np.any(tree_model.values[tree_model.leaf_mask] == np.inf) or np.any(
tree_model.values[tree_model.leaf_mask] == -np.inf
):
warning_flag = True
# replace +inf with 14 and -inf with -14
tree_model.values = np.where(tree_model.values == np.inf, 14, tree_model.values)
tree_model.values = np.where(tree_model.values == -np.inf, -14, tree_model.values)
tree_model.compute_empty_prediction() # recompute the empty prediction
tree_model.original_output_type = output_type
# transform logit to probability
if original_output_type == "logit" and output_type == "probability":
tree_model.values = 1 / (1 + np.exp(-tree_model.values))
tree_model.compute_empty_prediction() # recompute the empty prediction
tree_model.original_output_type = output_type
return tree_model, warning_flag
20 changes: 5 additions & 15 deletions shapiq/explainer/tree/conversion/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,26 @@
def convert_sklearn_forest(
tree_model: Model,
class_label: Optional[int] = None,
output_type: str = "raw",
) -> list[TreeModel]:
"""Transforms a scikit-learn random forest to the format used by shapiq.
Args:
tree_model: The scikit-learn random forest model to convert.
class_label: The class label of the model to explain. Only used for classification models.
Defaults to 0.
output_type: Denotes if the tree output values should be transformed or not. Defaults
to None ('raw'). Possible values are 'raw', 'probability', and 'logits'.
Returns:
The converted random forest model.
"""
scaling = 1.0 / len(tree_model.estimators_)
return [
convert_sklearn_tree(
tree, scaling=scaling, class_label=class_label, output_type=output_type
)
convert_sklearn_tree(tree, scaling=scaling, class_label=class_label)
for tree in tree_model.estimators_
]


def convert_sklearn_tree(
tree_model: Model,
class_label: Optional[int] = None,
scaling: float = 1.0,
output_type: str = "raw",
tree_model: Model, class_label: Optional[int] = None, scaling: float = 1.0
) -> TreeModel:
"""Convert a scikit-learn decision tree to the format used by shapiq.
Expand All @@ -49,12 +41,11 @@ def convert_sklearn_tree(
class_label: The class label of the model to explain. Only used for classification models.
Defaults to 0.
scaling: The scaling factor for the tree values.
output_type: Denotes if the tree output values should be transformed or not. Defaults
to None ('raw'). Possible values are 'raw', 'probability', and 'logits'.
Returns:
The converted decision tree model.
"""
output_type = "raw"
tree_values = tree_model.tree_.value.copy() * scaling
# set class label if not given and model is a classifier
if safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier") and class_label is None:
Expand All @@ -66,9 +57,7 @@ def convert_sklearn_tree(
tree_values = tree_values[:, 0, :]
tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
tree_values = tree_values[:, class_label]
if output_type != "raw":
# TODO: Add support for logits output type
raise NotImplementedError("Only raw output types are currently supported.")
output_type = "probability"
tree_values = tree_values.flatten()
return TreeModel(
children_left=tree_model.tree_.children_left,
Expand All @@ -78,4 +67,5 @@ def convert_sklearn_tree(
values=tree_values,
node_sample_weight=tree_model.tree_.weighted_n_node_samples,
empty_prediction=None, # compute empty prediction later
original_output_type=output_type,
)
6 changes: 4 additions & 2 deletions shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from interaction_values import InteractionValues

from .treeshapiq import TreeModel, TreeSHAPIQ
from .validation import _validate_model
from .validation import validate_tree_model


class TreeExplainer(Explainer):
Expand All @@ -22,7 +22,9 @@ def __init__(
output_type: str = "raw",
) -> None:
# validate and parse model
validated_model = _validate_model(model, class_label=class_label, output_type=output_type)
validated_model = validate_tree_model(
model, class_label=class_label, output_type=output_type
)
self._trees: Union[TreeModel, list[TreeModel]] = copy.deepcopy(validated_model)
if not isinstance(self._trees, list):
self._trees = [self._trees]
Expand Down
83 changes: 72 additions & 11 deletions shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,38 @@

from .base import EdgeTree, TreeModel
from .conversion.edges import create_edge_tree
from .validation import _validate_model
from .validation import validate_tree_model


class TreeSHAPIQ:
"""
The explainer for tree-based models using the TreeSHAP-IQ algorithm.
The explainer for tree-based models using the TreeSHAP-IQ algorithm. For a detailed presentation
of the algorithm, see the original paper: https://arxiv.org/abs/2401.12069.
TreeSHAP-IQ is an algorithm for computing Shapley Interaction values for tree-based models.
It is heavily based on the Linear TreeSHAP algorithm (outlined in https://proceedings.neurips.cc/paper_files/paper/2022/hash/a5a3b1ef79520b7cd122d888673a3ebc-Abstract-Conference.html)
but extended to compute Shapley Interaction values up to a given order. TreeSHAP-IQ needs to
visit each node only once and makes use of polynomial arithmetic to compute the Shapley
Interaction values efficiently.
Args:
model: The tree-based model to explain as an explainer.tree.conversion.TreeModel.
model: A single tree-based model to explain. Note unlike the TreeExplainer class,
TreeSHAP-IQ only supports a single tree model. The tree model can be a dictionary
representation of the tree, a `TreeModel` object, or any other tree model supported by
the `shapiq.explainer.tree.validation.validate_tree_model` function.
max_order: The maximum interaction order to be computed. An interaction order of 1
corresponds to the Shapley value. Any value higher than 1 computes the Shapley
interactions values up to that order. Defaults to 2.
interaction values up to that order. Defaults to 2.
min_order: The minimum interaction order to be computed. Defaults to 1.
interaction_type: The type of interaction to be computed. The interaction type can be
"k-SII" (default) or "SII", "STI", "FSI", "BZF".
"k-SII" (default), "SII", "STI", "FSI", or "BZF". All indices apart from "BZF" will
reduce to the "SV" (Shapley value) for order 1.
verbose: Whether to print information about the tree during initialization. Defaults to
False.
Note:
This class is not intended to be used directly. Instead, use the `TreeExplainer` class to
explain tree-based models which internally uses then the TreeSHAP-IQ algorithm.
"""

def __init__(
Expand All @@ -40,12 +57,17 @@ def __init__(
# set parameters
self._root_node_id = 0
self.verbose = verbose
if max_order < min_order or max_order < 1 or min_order < 1:
raise ValueError(
"The maximum order must be greater than the minimum order and both must be greater "
"than 0."
)
self._max_order: int = max_order
self._min_order: int = min_order
self._interaction_type: str = interaction_type

# validate and parse model
validated_model = _validate_model(model) # the parsed and validated model
validated_model = validate_tree_model(model) # the parsed and validated model
# TODO: add support for other sample weights
self._tree: TreeModel = copy.deepcopy(validated_model)
self._relevant_features: np.ndarray = np.array(list(self._tree.feature_ids), dtype=int)
Expand Down Expand Up @@ -151,6 +173,22 @@ def _compute_shapley_interaction_values(
quotient_poly_down: np.ndarray[float] = None,
depth: int = 0,
) -> None:
"""Computes the Shapley Interaction values for a given instance x_explain and interaction
order. This function is called recursively for each node in the tree.
Args:
x_explain: The instance to be explained.
order: The interaction order for which the Shapley Interaction values should be
computed. Defaults to 1.
node_id: The node ID of the current node in the tree. Defaults to 0.
summary_poly_down: The summary polynomial for the current node. Defaults to None (init).
summary_poly_up: The summary polynomial propagated up the tree. Defaults to None (init).
interaction_poly_down: The interaction polynomial for the current node. Defaults to
None (init).
quotient_poly_down: The quotient polynomial for the current node. Defaults to None
(init).
depth: The depth of the current node in the tree. Defaults to 0.
"""
# reset activations for new calculations
if node_id == 0:
self._activations.fill(False)
Expand Down Expand Up @@ -282,7 +320,7 @@ def _compute_shapley_interaction_values(
self.Ns_id[self.n_interpolation_size, : self.n_interpolation_size],
)
interaction_update *= self._psi(
summary_poly_up[depth, :], # TODO fix: wrong at this point (should not be zero)
summary_poly_up[depth, :],
D_power,
quotient_poly_down[depth, interactions_seen],
self.Ns,
Expand Down Expand Up @@ -326,7 +364,7 @@ def _compute_shapley_interaction_values(
self.shapley_interactions[interactions_with_ancestor_to_update] -= update

@staticmethod
def _psi(E, D_power, quotient_poly, Ns, degree):
def _psi(E, D_power, quotient_poly, Ns, degree) -> np.ndarray[float]:
# TODO: add docstring
d = degree + 1
n = Ns[d, :d]
Expand Down Expand Up @@ -355,7 +393,21 @@ def _get_polynomials(
summary_poly_up: np.ndarray[float] = None,
interaction_poly_down: np.ndarray[float] = None,
quotient_poly_down: np.ndarray[float] = None,
):
) -> tuple[np.ndarray[float], np.ndarray[float], np.ndarray[float], np.ndarray[float]]:
"""Retrieves the polynomials for a given interaction order. It initializes the polynomials
for the first call of the recursive explanation function.
Args:
order: The interaction order for which the polynomials should be loaded.
summary_poly_down: The summary polynomial for the current node. Defaults to None.
summary_poly_up: The summary polynomial propagated up the tree. Defaults to None.
interaction_poly_down: The interaction polynomial for the current node. Defaults to None.
quotient_poly_down: The quotient polynomial for the current node. Defaults to None.
Returns:
tuple: The summary polynomial down, the summary polynomial up, the interaction polynomial
down, and the quotient polynomial down.
"""
if summary_poly_down is None:
summary_poly_down = np.zeros((self._edge_tree.max_depth + 1, self.n_interpolation_size))
summary_poly_down[0, :] = 1
Expand All @@ -381,7 +433,7 @@ def _get_polynomials(
quotient_poly_down[0, :] = 1
return summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down

def _prepare_variables_for_order(self, interaction_order: int):
def _prepare_variables_for_order(self, interaction_order: int) -> None:
"""Retrieves the precomputed variables for a given interaction order. This function is
called before the recursive explanation function is called.
Expand Down Expand Up @@ -458,7 +510,16 @@ def _precalculate_interaction_ancestors(
self, interaction_order, n_features
) -> dict[int, np.ndarray]:
"""Calculates the position of the ancestors of the interactions for the tree for a given
order of interactions."""
order of interactions.
Args:
interaction_order: The interaction order for which the ancestors should be computed.
n_features: The number of features in the model.
Returns:
subset_ancestors: A dictionary containing the ancestors of the interactions for each
node in the tree.
"""

# stores position of interactions
counter_interaction = 0
Expand Down
Loading

0 comments on commit 6dfc249

Please sign in to comment.