Skip to content

Commit

Permalink
Merge pull request #103 from mmschlk/80-add-tabular-game-with-example…
Browse files Browse the repository at this point in the history
…-datasets

The California Housing Benchmark Game
  • Loading branch information
mmschlk authored Apr 5, 2024
2 parents c3d3d3a + 26e0fac commit d25809e
Show file tree
Hide file tree
Showing 14 changed files with 20,973 additions and 20,699 deletions.
41,282 changes: 20,641 additions & 20,641 deletions data/california.csv

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions shapiq/games/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from .dummy import DummyGame
from .imputer import MarginalImputer
from .sentiment_language import SentimentClassificationGame
from .tabular import CaliforniaHousing, FeatureSelectionGame, LocalExplanation

__all__ = [
"DummyGame",
"Game",
"MarginalImputer",
"SentimentClassificationGame",
"LocalExplanation",
"FeatureSelectionGame",
"CaliforniaHousing",
]
13 changes: 9 additions & 4 deletions shapiq/games/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class Game(ABC):
value for the empty coalition is zero. Defaults to `None`. If `normalization` is set
to `False` this value is not required. Otherwise, the value is needed to normalize and
center the game. If no value is provided, the game raise a warning.
path_to_values: The path to load the game values from. If the path is provided, the game
values are loaded from the given path. Defaults to `None`.
Note:
This class is an abstract base class and should not be instantiated directly. All games
Expand All @@ -31,16 +33,18 @@ class Game(ABC):
@abstractmethod
def __init__(
self,
n_players: int,
n_players: Optional[int] = None,
normalize: bool = True,
normalization_value: Optional[float] = None,
path_to_values: Optional[str] = None,
) -> None:
# define storage variables
self.value_storage: np.ndarray = np.zeros(0, dtype=float)
self.coalition_lookup: dict[tuple[int, ...], int] = {}
self.n_players: int = n_players # if path_to_values is provided, this will be overwritten

# define some handy variables describing the game
self.n_players: int = n_players
if path_to_values is not None:
self.load_values(path_to_values, precomputed=True)

# setup normalization of the game
self.normalization_value: float = 0.0
Expand Down Expand Up @@ -226,11 +230,12 @@ def load_values(self, path: str, precomputed: bool = False) -> None:

data = np.load(path)
n_players = data["n_players"]
if n_players != self.n_players:
if self.n_players is not None and n_players != self.n_players:
raise ValueError(
f"The number of players in the game ({self.n_players}) does not match the number "
f"of players in the saved game ({n_players})."
)
self.n_players = n_players
self.value_storage = data["values"]
self.coalition_lookup = transform_array_to_coalitions(data["coalitions"])
self.precompute_flag = precomputed
Expand Down
4 changes: 2 additions & 2 deletions shapiq/games/imputer/marginal_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
if normalize:
self.normalization_value = self.empty_prediction

def value_function(self, coalitions: np.ndarray) -> np.ndarray:
def value_function(self, coalitions: np.ndarray[bool]) -> np.ndarray[float]:
"""Imputes the missing values of a data point and calls the model.
Args:
Expand Down Expand Up @@ -118,7 +118,7 @@ def init_background(self, x_background: np.ndarray) -> "MarginalImputer":
self.replacement_data[:, feature] = summarized_feature
return self

def fit(self, x_explain: np.ndarray[float]) -> "MarginalImputer":
def fit(self, x_explain: np.ndarray) -> "MarginalImputer":
"""Fits the imputer to the explanation point.
Args:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit d25809e

Please sign in to comment.