Skip to content

Commit

Permalink
Merge pull request #95 from mmschlk/68-add-unbiased-kernelshap-approx…
Browse files Browse the repository at this point in the history
…imator

68 add unbiased kernelshap approximator
  • Loading branch information
mmschlk authored Apr 3, 2024
2 parents 284a498 + 0bdcefe commit ea817fb
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 4 deletions.
3 changes: 2 additions & 1 deletion shapiq/approximator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .permutation.sii import PermutationSamplingSII
from .permutation.sti import PermutationSamplingSTI
from .regression import KernelSHAP, RegressionFSI, RegressionSII
from .shapiq import ShapIQ
from .shapiq import ShapIQ, UnbiasedKernelSHAP

__all__ = [
"PermutationSamplingSII",
Expand All @@ -13,6 +13,7 @@
"RegressionFSI",
"RegressionSII",
"ShapIQ",
"UnbiasedKernelSHAP",
"transforms_sii_to_ksii",
"convert_ksii_into_one_dimension",
]
3 changes: 2 additions & 1 deletion shapiq/approximator/shapiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains the shapiq estimator to approximate all cardinal interaction indices."""

from .shapiq import ShapIQ
from .unbiased_kernelshap import UnbiasedKernelSHAP

__all__ = ["ShapIQ"]
__all__ = ["ShapIQ", "UnbiasedKernelSHAP"]
4 changes: 2 additions & 2 deletions shapiq/approximator/shapiq/shapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from shapiq.interaction_values import InteractionValues
from shapiq.utils import powerset

AVAILABLE_INDICES_SHAPIQ = {"SII", "STI", "FSI", "k-SII"}
AVAILABLE_INDICES_SHAPIQ = {"SII", "STI", "FSI", "k-SII", "SV"}


class ShapIQ(Approximator, ShapleySamplingMixin, KShapleyMixin):
Expand Down Expand Up @@ -229,7 +229,7 @@ def _weight_kernel(self, subset_size: int, interaction_size: int) -> float:
Returns:
float: The weight for the interaction type.
"""
if self.index == "SII" or self.index == "k-SII": # in both cases return SII kernel
if self.index == "SII" or self.index == "k-SII" or self.index == "SV": # SII kernel default
return self._sii_weight_kernel(subset_size, interaction_size)
elif self.index == "STI":
return self._sti_weight_kernel(subset_size, interaction_size)
Expand Down
49 changes: 49 additions & 0 deletions shapiq/approximator/shapiq/unbiased_kernelshap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""This module contains the Unbiased KernelSHAP approximation method for the Shapley value (SV).
The Unbiased KernelSHAP method is a variant of the KernelSHAP. However, it was shown that
Unbiased KernelSHAP is a more specific variant of the ShapIQ interaction method.
"""
from typing import Optional

from .shapiq import ShapIQ


class UnbiasedKernelSHAP(ShapIQ):

"""The Unbiased KernelSHAP approximator for estimating the Shapley value (SV).
The Unbiased KernelSHAP estimator is a variant of the KernelSHAP estimator (though deeply
different). Unbiased KernelSHAP was proposed in Covert and Lee's
[original paper](http://proceedings.mlr.press/v130/covert21a/covert21a.pdf) as an unbiased
version of KernelSHAP. Recently, in Fumagalli et al.'s
[paper](https://proceedings.neurips.cc/paper_files/paper/2023/hash/264f2e10479c9370972847e96107db7f-Abstract-Conference.html),
it was shown that Unbiased KernelSHAP is a more specific variant of the ShapIQ approximation
method (Theorem 4.5).
Args:
n: The number of players.
random_state: The random state of the estimator. Defaults to `None`.
Example:
>>> from shapiq.games import DummyGame
>>> from shapiq.approximator import UnbiasedKernelSHAP
>>> game = DummyGame(n=5, interaction=(1, 2))
>>> approximator = UnbiasedKernelSHAP(n=5)
>>> approximator.approximate(budget=100, game=game)
InteractionValues(
index=SV, order=1, estimated=False, estimation_budget=32,
values={
(0,): 0.2,
(1,): 0.7,
(2,): 0.7,
(3,): 0.2,
(4,): 0.2,
}
)
"""

def __init__(
self,
n: int,
random_state: Optional[int] = None,
):
super().__init__(n, 1, "SV", False, random_state)
56 changes: 56 additions & 0 deletions tests/tests_approximators/test_approximator_unbiased_ksh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""This test module contains all tests for the Unbiased KernelSHAP approximator."""

import copy

import pytest

from shapiq.approximator import UnbiasedKernelSHAP
from shapiq.games import DummyGame


def test_basic_functionality():
"""Tests the initialization of the RegressionFSI approximator."""
n_players = 7

approximator = UnbiasedKernelSHAP(n_players)
assert approximator.n == n_players
assert approximator.max_order == 1
assert approximator.top_order is False
assert approximator.min_order == 1
assert approximator.iteration_cost == 1
assert approximator.index == "SV"

approximator_copy = copy.copy(approximator)
approximator_deepcopy = copy.deepcopy(approximator)
approximator_deepcopy.index = "something"
assert approximator_copy == approximator # check that the copy is equal
assert approximator_deepcopy != approximator # check that the deepcopy is not equal
approximator_string = str(approximator)
assert repr(approximator) == approximator_string
assert hash(approximator) == hash(approximator_copy)
assert hash(approximator) != hash(approximator_deepcopy)

# test that the approximator can approximate the correct values
interaction = (1, 2)
game = DummyGame(n_players, interaction)
budget = 2**n_players

approximator = UnbiasedKernelSHAP(n_players)
sv_estimates = approximator.approximate(budget, game)
assert sv_estimates.n_players == n_players
assert sv_estimates.max_order == 1
assert sv_estimates.min_order == 1
assert sv_estimates.index == "SV"
assert sv_estimates.estimated is False
assert sv_estimates.estimation_budget == budget

# check that the values are correct
assert sv_estimates[()] == 0.0
assert sv_estimates[(0,)] == pytest.approx(0.1429, 0.001)
assert sv_estimates[(1,)] == pytest.approx(0.6429, 0.001)

# smaller budget
budget = int(budget * 0.75)
sv_estimates = approximator.approximate(budget, game)
assert sv_estimates[(0,)] == pytest.approx(0.1429, abs=0.2)
assert sv_estimates[(1,)] == pytest.approx(0.6429, abs=0.2)

0 comments on commit ea817fb

Please sign in to comment.