Skip to content

Commit

Permalink
Add subset function to interaction values (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
r-visser authored Dec 18, 2024
1 parent b399dd2 commit fb47de6
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- makes abbreviations in the `plot` module optional [#281](https://github.com/mmschlk/shapiq/issues/281)
- adds the `upset_plot` function to the `plot` module to visualize the interactions of higher-order [#290](https://github.com/mmschlk/shapiq/issues/290)
- adds support for IsoForest models to explainer and tree explainer [#278](https://github.com/mmschlk/shapiq/issues/278)
- adds support for sub-selection of players in the interaction values data class [#276](https://github.com/mmschlk/shapiq/issues/276) which allows retrieving interaction values for a subset of players

### v1.1.1 (2024-11-13)

Expand Down
47 changes: 47 additions & 0 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,53 @@ def get_n_order(
baseline_value=self.baseline_value,
)

def get_subset(self, players: list[int]) -> "InteractionValues":
"""Selects a subset of players from the InteractionValues object.
Args:
players (list[int]): List of players to select from the InteractionValues object.
Returns:
InteractionValues: Filtered InteractionValues object containing only values related to
selected players.
Example:
>>> interaction_values = InteractionValues(
... values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=0.0,
... )
>>> interaction_values.get_subset([0, 1]).dict_values
{(0,): 0.1, (1,): 0.2, (0, 1): 0.3}
>>> interaction_values.get_subset([0, 2]).dict_values
{(0,): 0.1, (2,): 0.3, (0, 2): 0.4}
>>> interaction_values.get_subset([1]).dict_values
{(1,): 0.2}
"""
keys = self.interaction_lookup.keys()
idx = [i for i, key in enumerate(keys) if all(p in players for p in key)]
new_values = self.values[idx]
new_interaction_lookup = {
key: self.interaction_lookup[key] for i, key in enumerate(keys) if i in idx
}
n_players = self.n_players - len(players)

return InteractionValues(
values=new_values,
index=self.index,
max_order=self.max_order,
n_players=n_players,
min_order=self.min_order,
interaction_lookup=new_interaction_lookup,
estimated=self.estimated,
estimation_budget=self.estimation_budget,
baseline_value=self.baseline_value,
)

def save(self, path: str, as_pickle: bool = True) -> None:
"""Save the InteractionValues object to a file.
Expand Down
39 changes: 39 additions & 0 deletions tests/test_base_interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,42 @@ def test_plot():
_ = interaction_values.plot_network(feature_names=["a" for _ in range(n)])
_ = interaction_values.plot_stacked_bar()
_ = interaction_values.plot_stacked_bar(feature_names=["a" for _ in range(n)])


def test_subset():
n = 5
min_order = 1
max_order = 3
values = np.random.rand(2**n - 1)
interaction_lookup = {
interaction: i for i, interaction in enumerate(powerset(range(n), min_order, max_order))
}
interaction_values = InteractionValues(
values=values,
index=None,
max_order=max_order,
n_players=n,
min_order=min_order,
interaction_lookup=interaction_lookup,
estimated=False,
estimation_budget=0,
baseline_value=0.0,
)

subset_players = [0, 1, 2]
subset_interaction_values = interaction_values.get_subset(subset_players)

assert subset_interaction_values.n_players == n - len(subset_players)
assert all(
all(p in subset_players for p in key)
for key in subset_interaction_values.interaction_lookup.keys()
)
assert len(subset_interaction_values.values) == len(
subset_interaction_values.interaction_lookup
)
assert interaction_values.baseline_value == subset_interaction_values.baseline_value
assert subset_interaction_values.min_order == interaction_values.min_order
assert subset_interaction_values.max_order == interaction_values.max_order
assert subset_interaction_values.estimated == interaction_values.estimated
assert subset_interaction_values.estimation_budget == interaction_values.estimation_budget
assert subset_interaction_values.index == interaction_values.index

0 comments on commit fb47de6

Please sign in to comment.