-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #95 from mmschlk/68-add-unbiased-kernelshap-approx…
…imator 68 add unbiased kernelshap approximator
- Loading branch information
Showing
5 changed files
with
111 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
"""This module contains the shapiq estimator to approximate all cardinal interaction indices.""" | ||
|
||
from .shapiq import ShapIQ | ||
from .unbiased_kernelshap import UnbiasedKernelSHAP | ||
|
||
__all__ = ["ShapIQ"] | ||
__all__ = ["ShapIQ", "UnbiasedKernelSHAP"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""This module contains the Unbiased KernelSHAP approximation method for the Shapley value (SV). | ||
The Unbiased KernelSHAP method is a variant of the KernelSHAP. However, it was shown that | ||
Unbiased KernelSHAP is a more specific variant of the ShapIQ interaction method. | ||
""" | ||
from typing import Optional | ||
|
||
from .shapiq import ShapIQ | ||
|
||
|
||
class UnbiasedKernelSHAP(ShapIQ): | ||
|
||
"""The Unbiased KernelSHAP approximator for estimating the Shapley value (SV). | ||
The Unbiased KernelSHAP estimator is a variant of the KernelSHAP estimator (though deeply | ||
different). Unbiased KernelSHAP was proposed in Covert and Lee's | ||
[original paper](http://proceedings.mlr.press/v130/covert21a/covert21a.pdf) as an unbiased | ||
version of KernelSHAP. Recently, in Fumagalli et al.'s | ||
[paper](https://proceedings.neurips.cc/paper_files/paper/2023/hash/264f2e10479c9370972847e96107db7f-Abstract-Conference.html), | ||
it was shown that Unbiased KernelSHAP is a more specific variant of the ShapIQ approximation | ||
method (Theorem 4.5). | ||
Args: | ||
n: The number of players. | ||
random_state: The random state of the estimator. Defaults to `None`. | ||
Example: | ||
>>> from shapiq.games import DummyGame | ||
>>> from shapiq.approximator import UnbiasedKernelSHAP | ||
>>> game = DummyGame(n=5, interaction=(1, 2)) | ||
>>> approximator = UnbiasedKernelSHAP(n=5) | ||
>>> approximator.approximate(budget=100, game=game) | ||
InteractionValues( | ||
index=SV, order=1, estimated=False, estimation_budget=32, | ||
values={ | ||
(0,): 0.2, | ||
(1,): 0.7, | ||
(2,): 0.7, | ||
(3,): 0.2, | ||
(4,): 0.2, | ||
} | ||
) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n: int, | ||
random_state: Optional[int] = None, | ||
): | ||
super().__init__(n, 1, "SV", False, random_state) |
56 changes: 56 additions & 0 deletions
56
tests/tests_approximators/test_approximator_unbiased_ksh.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""This test module contains all tests for the Unbiased KernelSHAP approximator.""" | ||
|
||
import copy | ||
|
||
import pytest | ||
|
||
from shapiq.approximator import UnbiasedKernelSHAP | ||
from shapiq.games import DummyGame | ||
|
||
|
||
def test_basic_functionality(): | ||
"""Tests the initialization of the RegressionFSI approximator.""" | ||
n_players = 7 | ||
|
||
approximator = UnbiasedKernelSHAP(n_players) | ||
assert approximator.n == n_players | ||
assert approximator.max_order == 1 | ||
assert approximator.top_order is False | ||
assert approximator.min_order == 1 | ||
assert approximator.iteration_cost == 1 | ||
assert approximator.index == "SV" | ||
|
||
approximator_copy = copy.copy(approximator) | ||
approximator_deepcopy = copy.deepcopy(approximator) | ||
approximator_deepcopy.index = "something" | ||
assert approximator_copy == approximator # check that the copy is equal | ||
assert approximator_deepcopy != approximator # check that the deepcopy is not equal | ||
approximator_string = str(approximator) | ||
assert repr(approximator) == approximator_string | ||
assert hash(approximator) == hash(approximator_copy) | ||
assert hash(approximator) != hash(approximator_deepcopy) | ||
|
||
# test that the approximator can approximate the correct values | ||
interaction = (1, 2) | ||
game = DummyGame(n_players, interaction) | ||
budget = 2**n_players | ||
|
||
approximator = UnbiasedKernelSHAP(n_players) | ||
sv_estimates = approximator.approximate(budget, game) | ||
assert sv_estimates.n_players == n_players | ||
assert sv_estimates.max_order == 1 | ||
assert sv_estimates.min_order == 1 | ||
assert sv_estimates.index == "SV" | ||
assert sv_estimates.estimated is False | ||
assert sv_estimates.estimation_budget == budget | ||
|
||
# check that the values are correct | ||
assert sv_estimates[()] == 0.0 | ||
assert sv_estimates[(0,)] == pytest.approx(0.1429, 0.001) | ||
assert sv_estimates[(1,)] == pytest.approx(0.6429, 0.001) | ||
|
||
# smaller budget | ||
budget = int(budget * 0.75) | ||
sv_estimates = approximator.approximate(budget, game) | ||
assert sv_estimates[(0,)] == pytest.approx(0.1429, abs=0.2) | ||
assert sv_estimates[(1,)] == pytest.approx(0.6429, abs=0.2) |