diff --git a/shapiq/approximator/k_sii.py b/shapiq/approximator/k_sii.py index d795c68b..511cac52 100644 --- a/shapiq/approximator/k_sii.py +++ b/shapiq/approximator/k_sii.py @@ -112,8 +112,11 @@ def _calculate_ksii_from_sii( `min_order`, and `max_order` parameters. Defaults to `None`. Returns: - The nSII values. + The k-SII values. """ + if interaction_lookup is None: + interaction_lookup = generate_interaction_lookup(n, 1, max_order) + # compute nSII values from SII values bernoulli_numbers = bernoulli(max_order) nsii_values = np.zeros_like(sii_values) @@ -128,8 +131,11 @@ def _calculate_ksii_from_sii( # go over all subsets T of length |S| + 1, ..., n that contain S for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order): if set(subset).issubset(T): - effect_index = interaction_lookup[T] # get the index of T - effect_value = sii_values[effect_index] # get the effect of T + try: + effect_index = interaction_lookup[T] # get the index of T + effect_value = sii_values[effect_index] # get the effect of T + except KeyError: + effect_value = 0 # if T is not in the interaction_lookup # TODO: verify this bernoulli_factor = bernoulli_numbers[len(T) - interaction_size] ksii_value += bernoulli_factor * effect_value nsii_values[interaction_index] = ksii_value diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index bfaf0ded..fb0aa772 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -17,6 +17,7 @@ def __init__( model: Union[dict, TreeModel, Any], max_order: int = 2, min_order: int = 1, + interaction_type: str = "k-SII", class_label: Optional[int] = None, output_type: str = "raw", ) -> None: @@ -34,7 +35,7 @@ def __init__( # setup explainers for all trees self._treeshapiq_explainers: list[TreeSHAPIQ] = [ - TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type="SII") + TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type=interaction_type) for _tree in self._trees ] diff --git a/tests/tests_approximators/test_approximator_ksii_estimation.py b/tests/tests_approximators/test_approximator_ksii_estimation.py index 61ee4c8a..e3cf0cd5 100644 --- a/tests/tests_approximators/test_approximator_ksii_estimation.py +++ b/tests/tests_approximators/test_approximator_ksii_estimation.py @@ -70,3 +70,7 @@ def test_nsii_estimation(sii_approximator, ksii_approximator): assert isinstance(transformed, np.ndarray) with pytest.raises(ValueError): _ = transforms_sii_to_ksii(sii_estimates.values) + # check with interaction_lookup = None + sii_estimates.interaction_lookup = None + transformed = transforms_sii_to_ksii(sii_estimates) + assert transformed.index == "k-SII"