Skip to content

Commit

Permalink
Merge pull request #65 from mmschlk/64-interaction_type-not-in-treeex…
Browse files Browse the repository at this point in the history
…plainer

adds interaction_type as parameter to TreeExplainer fixes #64 and closes #64
  • Loading branch information
mmschlk authored Mar 22, 2024
2 parents 87fa1ca + 4494843 commit d4f0d45
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
12 changes: 9 additions & 3 deletions shapiq/approximator/k_sii.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ def _calculate_ksii_from_sii(
`min_order`, and `max_order` parameters. Defaults to `None`.
Returns:
The nSII values.
The k-SII values.
"""
if interaction_lookup is None:
interaction_lookup = generate_interaction_lookup(n, 1, max_order)

# compute nSII values from SII values
bernoulli_numbers = bernoulli(max_order)
nsii_values = np.zeros_like(sii_values)
Expand All @@ -128,8 +131,11 @@ def _calculate_ksii_from_sii(
# go over all subsets T of length |S| + 1, ..., n that contain S
for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order):
if set(subset).issubset(T):
effect_index = interaction_lookup[T] # get the index of T
effect_value = sii_values[effect_index] # get the effect of T
try:
effect_index = interaction_lookup[T] # get the index of T
effect_value = sii_values[effect_index] # get the effect of T
except KeyError:
effect_value = 0 # if T is not in the interaction_lookup # TODO: verify this
bernoulli_factor = bernoulli_numbers[len(T) - interaction_size]
ksii_value += bernoulli_factor * effect_value
nsii_values[interaction_index] = ksii_value
Expand Down
3 changes: 2 additions & 1 deletion shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
model: Union[dict, TreeModel, Any],
max_order: int = 2,
min_order: int = 1,
interaction_type: str = "k-SII",
class_label: Optional[int] = None,
output_type: str = "raw",
) -> None:
Expand All @@ -34,7 +35,7 @@ def __init__(

# setup explainers for all trees
self._treeshapiq_explainers: list[TreeSHAPIQ] = [
TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type="SII")
TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type=interaction_type)
for _tree in self._trees
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,7 @@ def test_nsii_estimation(sii_approximator, ksii_approximator):
assert isinstance(transformed, np.ndarray)
with pytest.raises(ValueError):
_ = transforms_sii_to_ksii(sii_estimates.values)
# check with interaction_lookup = None
sii_estimates.interaction_lookup = None
transformed = transforms_sii_to_ksii(sii_estimates)
assert transformed.index == "k-SII"

0 comments on commit d4f0d45

Please sign in to comment.