Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/github_actions/pypa/gh-action-pyp…
Browse files Browse the repository at this point in the history
…i-publish-1.12.4
  • Loading branch information
mmschlk authored Feb 5, 2025
2 parents 2369434 + 0526f90 commit a8295be
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 21 deletions.
18 changes: 9 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
# this requirements.txt file is maintained and bumbped by dependabot
# this file is used to see weather new versions of the libraries are available and break shapiq
black==24.10.0
black==25.1.0
colour==0.1.5
coverage==7.6.10
matplotlib==3.10.0
networkx==3.4.2
pandas==2.2.3
pytest==8.3.4
ruff==0.8.4
scikit-image==0.25.0
scikit-learn==1.6.0
scipy==1.14.1
ruff==0.9.4
scikit-image==0.25.1
scikit-learn==1.6.1
scipy==1.15.1
tqdm==4.67.1
torch==2.5.1
torchvision==0.20.1
transformers==4.47.1
torch==2.6.0
torchvision==0.21.0
transformers==4.48.2
tensorflow==2.18.0
tf-keras==2.18.0
xgboost==2.1.3
numpy==1.26.4
requests==2.32.3
lightgbm==4.5.0
tabpfn==2.0.3; python_version <= '3.11'
tabpfn==2.0.5
8 changes: 7 additions & 1 deletion shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,13 @@ def _init_summary_polynomials(self):
interaction_order=order, n_features=self._n_features_in_tree
)
self.subset_ancestors_store[order] = subset_ancestors
self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)

# If the tree has only one feature, we assign a default value of 0
if self.n_interpolation_size == 1:
self.D_store[order] = np.array([0])
else:
self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)

self.D_powers_store[order] = self._cache(self.D_store[order])
if self._index in ("SV", "SII", "k-SII"):
self.Ns_store[order] = self._get_N(self.D_store[order])
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ requests==2.32.3
lightgbm==4.5.0
tf-keras==2.18.0
tensorflow==2.18.0
tabpfn==2.0.3; python_version <= '3.11'
tabpfn==2.0.5
4 changes: 0 additions & 4 deletions tests/tests_explainer/test_tabpfn_explainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""This test module tests the TabPFNExplainer object."""

import sys

import pytest

from shapiq import Explainer, InteractionValues, TabPFNExplainer, TabularExplainer


@pytest.mark.skipif(sys.version_info > (3, 11), reason="requires python3.11 or lower")
def test_tabpfn_explainer_clf(tabpfn_classification_problem):
"""Test the TabPFNExplainer class for classification problems."""
import tabpfn
Expand All @@ -34,7 +31,6 @@ def test_tabpfn_explainer_clf(tabpfn_classification_problem):
assert isinstance(explainer, TabularExplainer)


@pytest.mark.skipif(sys.version_info > (3, 11), reason="requires python3.11 or lower")
def test_tabpfn_explainer_reg(tabpfn_regression_problem):
"""Test the TabPFNExplainer class for regression problems."""
import tabpfn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from shapiq.explainer.tree import TreeModel, TreeSHAPIQ
from shapiq.explainer.tree import TreeExplainer, TreeModel, TreeSHAPIQ


def test_init(dt_clf_model, background_clf_data):
Expand Down Expand Up @@ -132,3 +132,29 @@ def test_edge_case_params():
# test with max_order = 0
with pytest.raises(ValueError):
_ = TreeSHAPIQ(model=tree_model, max_order=0)


def test_no_bug_with_one_feature_tree():
# create the dataset
X = np.array(
[
[1, 1, 1, 1],
[1, 1, 1, 2],
[2, 1, 1, 1],
[3, 2, 1, 1],
]
)

# Define simple one feature tree
tree = {
"children_left": np.array([1, -1, -1]),
"children_right": np.array([2, -1, -1]),
"features": np.array([0, -2, -2]),
"thresholds": np.array([2.5, -2, -2]),
"values": np.array([0.5, 0.0, 1]),
"node_sample_weight": np.array([14, 5, 9]),
}
tree = TreeModel(**tree)
explainer = TreeExplainer(model=tree, index="SV", max_order=1)
shapley_values = explainer.explain(X[2])
print(shapley_values)
5 changes: 0 additions & 5 deletions tests/tests_imputer/test_tabpfn_imputer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
"""This test module tests the tabpfn imputer object."""

import sys

import numpy as np
import pytest

from shapiq import TabPFNImputer
from shapiq.explainer.utils import get_predict_function_and_model_type


@pytest.mark.skipif(sys.version_info > (3, 11), reason="requires python3.11 or lower")
def test_tabpfn_imputer(tabpfn_classification_problem):
"""Test the TabPFNImputer class."""
import tabpfn
Expand Down Expand Up @@ -42,7 +39,6 @@ def test_tabpfn_imputer(tabpfn_classification_problem):
assert model.n_features_in_ == 1


@pytest.mark.skipif(sys.version_info > (3, 11), reason="requires python3.11 or lower")
def test_empty_prediction(tabpfn_classification_problem):
"""Tests the TabPFNImputer with a manual empty prediction."""
import tabpfn
Expand Down Expand Up @@ -72,7 +68,6 @@ def test_empty_prediction(tabpfn_classification_problem):
assert output[0] == manual_empty_prediction


@pytest.mark.skipif(sys.version_info > (3, 11), reason="requires python3.11 or lower")
def test_tabpfn_imputer_validation(tabpfn_classification_problem):
"""Test that the TabPFNImputer raises a ValueError if no predict function is provided."""
import tabpfn
Expand Down

0 comments on commit a8295be

Please sign in to comment.