Tabpfn explainer (#302)
mmschlk authored Jan 15, 2025
1 parent 4c4b560 commit bd7597f
Showing 25 changed files with 1,067 additions and 179 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,9 @@
## Changelog

### v1.1.2 (2025-01-13)
- adds ``shapiq.TabPFNExplainer`` as a specialized version of ``shapiq.TabularExplainer``, offering a streamlined explainer for TabPFN models [#301](https://github.com/mmschlk/shapiq/issues/301)
- handles ``explainer.explain()`` through a common interface for all explainer classes, which now need to implement an ``explain_function()`` method
- adds the ``baseline_value`` into the ``InteractionValues`` object's value storage for the ``()`` interaction if ``min_order=0`` (usually the default) for all indices except ``SII`` (which has a different baseline value), so that the values are efficient (sum up to the model prediction) without awkward handling of the ``baseline_value`` attribute (see the sketch below)
- renames ``game_fun`` parameter in ``shapiq.ExactComputer`` to ``game`` [#297](https://github.com/mmschlk/shapiq/issues/297)
- adds a TabPFN example notebook to the documentation
- removes warning when class_index is not provided in explainers [#298](https://github.com/mmschlk/shapiq/issues/298)
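A minimal sketch of the first and third items above (not from the repository; the synthetic data and the ``TabPFNClassifier`` setup are assumptions for illustration):

```python
import numpy as np
import shapiq
from tabpfn import TabPFNClassifier  # tabpfn==2.0.3, python_version <= '3.11'

rng = np.random.default_rng(0)
X = rng.normal(size=(150, 5))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

model = TabPFNClassifier()
explainer = shapiq.TabPFNExplainer(model=model, data=X, labels=y)

# explain() is routed through the common interface; with min_order=0 the empty
# coalition () stores the baseline value, so the values sum to the prediction.
interaction_values = explainer.explain(X[0])
print(interaction_values[()])  # the baseline (empty-coalition) value
```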
492 changes: 383 additions & 109 deletions docs/source/notebooks/tabular_notebooks/explaining_tabpfn.ipynb

Large diffs are not rendered by default.

3 binary files changed (contents not shown).
1 change: 1 addition & 0 deletions requirements.txt
@@ -21,3 +21,4 @@ xgboost==2.1.3
numpy==1.26.4
requests==2.32.3
lightgbm==4.5.0
tabpfn==2.0.3; python_version <= '3.11'
8 changes: 5 additions & 3 deletions shapiq/__init__.py
@@ -2,7 +2,7 @@
the well established Shapley value and its generalization to interaction.
"""

__version__ = "1.1.2"
__version__ = "1.2.0"

# approximator classes
from .approximator import (
@@ -39,14 +39,14 @@
from .datasets import load_adult_census, load_bike_sharing, load_california_housing

# explainer classes
from .explainer import Explainer, TabularExplainer, TreeExplainer
from .explainer import Explainer, TabPFNExplainer, TabularExplainer, TreeExplainer

# exact computer classes
from .game_theory.exact import ExactComputer

# game classes
# imputer classes
from .games import BaselineImputer, ConditionalImputer, Game, MarginalImputer
from .games import BaselineImputer, ConditionalImputer, Game, MarginalImputer, TabPFNImputer

# base classes
from .interaction_values import InteractionValues
@@ -97,10 +97,12 @@
"Explainer",
"TabularExplainer",
"TreeExplainer",
"TabPFNExplainer",
# imputers
"MarginalImputer",
"BaselineImputer",
"ConditionalImputer",
"TabPFNImputer",
# plots
"network_plot",
"stacked_bar_plot",
3 changes: 2 additions & 1 deletion shapiq/explainer/__init__.py
@@ -1,7 +1,8 @@
"""Explainer objects, including TreeSHAP-IQ."""

from ._base import Explainer
from .tabpfn import TabPFNExplainer
from .tabular import TabularExplainer
from .tree import TreeExplainer

__all__ = ["Explainer", "TabularExplainer", "TreeExplainer"]
__all__ = ["Explainer", "TabularExplainer", "TreeExplainer", "TabPFNExplainer"]
86 changes: 67 additions & 19 deletions shapiq/explainer/_base.py
@@ -1,6 +1,8 @@
"""The base Explainer classes for the shapiq package."""

from abc import abstractmethod
from typing import Optional
from warnings import warn

import numpy as np

@@ -12,7 +14,10 @@ class Explainer:
"""The main Explainer class for a simpler user interface.
shapiq.Explainer is a simplified interface for the ``shapiq`` package. It detects between
TabularExplainer and TreeExplainer based on the model class.
:class:`~shapiq.explainer.tabular.TabularExplainer`,
:class:`~shapiq.explainer.tree.TreeExplainer`,
and :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. For a detailed description of the
different explainers, see the respective classes.
Args:
model: The model object to be explained.
@@ -32,24 +37,14 @@ def __init__(
) -> None:

self._model_class = print_class(model)
self._predict_function, self._model_type = get_predict_function_and_model_type(
self._shapiq_predict_function, self._model_type = get_predict_function_and_model_type(
model, self._model_class, class_index
)
self.model = model

if data is not None:
if not isinstance(data, np.ndarray):
raise TypeError("`data` must be a NumPy array.")
try:
pred = self.predict(data)
if isinstance(pred, np.ndarray):
if len(pred.shape) > 1:
raise ValueError()
else:
raise ValueError()
except Exception as e:
print(f"Error: The `data` provided is not compatible with the model. {e}")
pass
if self._model_type != "tabpfn":
self._validate_data(data)
self.data = data

# not super()
@@ -59,13 +54,66 @@ def __init__(
self.__class__ = _explainer
_explainer.__init__(self, model=model, data=data, class_index=class_index, **kwargs)

def explain(self, x: np.ndarray) -> InteractionValues:
"""Explain the model's prediction in terms of interaction values.
def _validate_data(self, data: np.ndarray, raise_error: bool = False) -> None:
"""Validate the data for compatibility with the model.
Args:
x: An instance/point/sample/observation to be explained.
data: A 2-dimensional matrix of inputs to be explained.
raise_error: Whether to raise an error if the data is not compatible with the model or
only print a warning. Defaults to ``False``.
Raises:
TypeError: If the data is not a NumPy array.
"""
message = "The `data` and the model must be compatible."
if not isinstance(data, np.ndarray):
message += " The `data` must be a NumPy array."
raise TypeError(message)
try:
# TODO (mmschlk): This can take a long time for large datasets and slow models
pred = self.predict(data)
if isinstance(pred, np.ndarray):
if len(pred.shape) > 1:
message += " The model's prediction must be a 1-dimensional array."
raise ValueError()
else:
message += " The model's prediction must be a NumPy array."
raise ValueError()
except Exception as e:
if raise_error:
raise ValueError(message) from e
else:
warn(message)

def explain(self, x: np.ndarray, *args, **kwargs) -> InteractionValues:
"""Explain a single prediction in terms of interaction values.
Args:
x: A numpy array of a data point to be explained.
*args: Additional positional arguments passed to the explainer.
**kwargs: Additional keyword-only arguments passed to the explainer.
Returns:
The interaction values of the prediction.
"""
explanation = self.explain_function(x=x, *args, **kwargs)
if explanation.min_order == 0:
explanation[()] = explanation.baseline_value
return explanation

@abstractmethod
def explain_function(self, x: np.ndarray, *args, **kwargs) -> InteractionValues:
"""Explain a single prediction in terms of interaction values.
Args:
x: A numpy array of a data point to be explained.
*args: Additional positional arguments passed to the explainer.
**kwargs: Additional keyword-only arguments passed to the explainer.
Returns:
The interaction values of the prediction.
"""
return {}
raise NotImplementedError("The method `explain_function` must be implemented in a subclass.")

def explain_X(
self, X: np.ndarray, n_jobs=None, random_state=None, **kwargs
@@ -104,4 +152,4 @@ def predict(self, x: np.ndarray) -> np.ndarray:
Args:
x: An instance/point/sample/observation to be explained.
"""
return self._predict_function(self.model, x)
return self._shapiq_predict_function(self.model, x)
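
As a usage sketch of the dispatch and validation flow above (not part of this commit; the scikit-learn model and random data are illustrative assumptions):

```python
import numpy as np
import shapiq
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.random((60, 4))
y = X.sum(axis=1)
model = RandomForestRegressor(n_estimators=10).fit(X, y)

# shapiq.Explainer detects the model type and becomes a TabularExplainer here;
# _validate_data() checks `data` up front and, by default (raise_error=False),
# warns instead of raising if the data and model are incompatible.
explainer = shapiq.Explainer(model=model, data=X)
interaction_values = explainer.explain(X[0])
```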
120 changes: 120 additions & 0 deletions shapiq/explainer/tabpfn.py
@@ -0,0 +1,120 @@
"""This module contains the TabPFNExplainer class, which is a class for explaining the predictions
of a TabPFN model."""

from typing import Optional, Union

import numpy as np

from ..approximator._base import Approximator
from .tabular import TabularExplainer
from .utils import ModelType, get_predict_function_and_model_type


class TabPFNExplainer(TabularExplainer):
"""The TabPFN explainer as the main interface for the shapiq package.
The ``TabPFNExplainer`` class is the dedicated interface for the ``shapiq`` package and
TabPFN[2]_ models such as the ``TabPFNClassifier`` and ``TabPFNRegressor``. The explainer
does not rely on classical imputation methods and is optimized for TabPFN's in-context learning
approach. The explanation paradigm for TabPFN is described in Rundel et al. (2024)[1]_. In
essence, the explainer is a wrapper around the ``TabularExplainer`` class and uses the same API.
Args:
model: Either a TabPFNClassifier or TabPFNRegressor model to be explained.
data: The background data to use for the explainer as a 2-dimensional array with shape
``(n_samples, n_features)``. The model is contextualized on this data.
labels: The labels for the background data as a 1-dimensional array with shape
``(n_samples,)``. The model is contextualized on these labels.
index: The index to explain the model with. Defaults to ``"k-SII"`` which computes the
k-Shapley Interaction Index. If ``max_order`` is set to 1, this corresponds to the
Shapley value (``index="SV"``). Options are:
- ``"SV"``: Shapley value
- ``"k-SII"``: k-Shapley Interaction Index
- ``"FSII"``: Faithful Shapley Interaction Index
- ``"STII"``: Shapley Taylor Interaction Index
- ``"SII"``: Shapley Interaction Index (not recommended for XAI since the values do
not sum up to the prediction)
x_test: An optional test data set on which to compute the model's empty prediction
(average prediction). If no test data is provided and ``empty_prediction`` is ``None``,
the last 20% of the background data is used as test data and the remaining 80% as
training data for contextualization. Defaults to ``None``.
empty_prediction: Optional value for the model's average prediction on an empty data point
(all features missing). If provided, this overrides ``x_test`` and skips the computation
of the empty prediction. Defaults to ``None``.
class_index: The class index of the model to explain. Defaults to ``None``, which sets
the class index to ``1`` by default for classification models; it is ignored for
regression models.
approximator: The approximator to use for calculating the Shapley values or Shapley
interactions. Can be a string or an instance of an approximator. Defaults to ``"auto"``.
verbose: Whether to show a progress bar during the computation. Defaults to ``False``.
Note that verbosity can slow down the computation for large datasets.
References:
.. [1] Rundel, D., Kobialka, J., von Crailsheim, C., Feurer, M., Nagler, T., Rügamer, D. (2024). Interpretable Machine Learning for TabPFN. In: Longo, L., Lapuschkin, S., Seifert, C. (eds) Explainable Artificial Intelligence. xAI 2024. Communications in Computer and Information Science, vol 2154. Springer, Cham. https://doi.org/10.1007/978-3-031-63797-1_23
.. [2] Hollmann, N., Müller, S., Purucker, L. et al. Accurate predictions on small data with a tabular foundation model. Nature 637, 319–326 (2025). https://doi.org/10.1038/s41586-024-08328-6
"""

def __init__(
self,
*,
model: ModelType,
data: np.ndarray,
labels: np.ndarray,
index: str = "k-SII",
max_order: int = 2,
x_test: Optional[np.ndarray] = None,
empty_prediction: Optional[float] = None,
class_index: Optional[int] = None,
approximator: Union[str, Approximator] = "auto",
verbose: bool = False,
):
from ..games.imputer.tabpfn_imputer import TabPFNImputer

_predict_function, _ = get_predict_function_and_model_type(model, class_index=class_index)
model._shapiq_predict_function = _predict_function

# check that data and labels have the same number of samples
if data.shape[0] != labels.shape[0]:
raise ValueError(
f"The number of samples in `data` and `labels` must be equal (got data.shape= "
f"{data.shape} and labels.shape={labels.shape})."
)
n_samples = data.shape[0]
x_train = data
y_train = labels

if x_test is None and empty_prediction is None:
sections = [int(0.8 * n_samples)]
x_train, x_test = np.split(data, sections)
y_train, _ = np.split(labels, sections)

if x_test is None:
x_test = x_train # is not used in the TabPFNImputer if empty_prediction is set

imputer = TabPFNImputer(
model=model,
x_train=x_train,
y_train=y_train,
x_test=x_test,
empty_prediction=empty_prediction,
verbose=verbose,
)

super().__init__(
model,
data=x_test,
imputer=imputer,
class_index=class_index,
approximator=approximator,
index=index,
max_order=max_order,
)
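
A brief constructor sketch for the class above (hypothetical data, not from the commit). When neither ``x_test`` nor ``empty_prediction`` is supplied, the constructor splits ``data`` into an 80% contextualization set and a 20% test set, as in the code above:

```python
import numpy as np
from tabpfn import TabPFNRegressor
from shapiq import TabPFNExplainer

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 8))
y = X @ rng.normal(size=8)

explainer = TabPFNExplainer(
    model=TabPFNRegressor(),
    data=X,  # first 80% contextualizes the model, last 20% becomes x_test
    labels=y,
    index="k-SII",
    max_order=2,
)
interaction_values = explainer.explain(X[0])
```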