Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors shapiq.games.benchmark into seperate shapiq.benchmark module. #175

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### development

- refactored the `shapiq.games.benchmark` module into a separate `shapiq.benchmark` module by moving all but the benchmark games into the new modul. This closes [#169](https://github.com/mmschlk/shapiq/issues/169) and makes benchmarking more flexible and convenient.
- add a legend to benchmark plots [#170](https://github.com/mmschlk/shapiq/issues/170)
- fix the force plot not showing and its baseline value
- improve tests for plots and benchmarks
Expand Down
10 changes: 5 additions & 5 deletions docs/source/notebooks/benchmark_approximators.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
}
},
"source": [
"from shapiq.games.benchmark.benchmark_config import (\n",
"from shapiq.benchmark import (\n",
" load_games_from_configuration,\n",
" print_benchmark_configurations,\n",
")"
Expand Down Expand Up @@ -569,7 +569,7 @@
},
"source": [
"# run the benchmark\n",
"from shapiq.games.benchmark.run import run_benchmark\n",
"from shapiq.benchmark import run_benchmark\n",
"\n",
"results = run_benchmark(\n",
" index=index,\n",
Expand Down Expand Up @@ -658,7 +658,7 @@
},
"source": [
"# plot the results\n",
"from shapiq.games.benchmark.plot import plot_approximation_quality\n",
"from shapiq.benchmark import plot_approximation_quality\n",
"\n",
"fig, axis = plot_approximation_quality(results)"
],
Expand Down Expand Up @@ -803,7 +803,7 @@
},
"source": [
"# run the benchmark\n",
"from shapiq.games.benchmark.run import run_benchmark\n",
"from shapiq.benchmark import run_benchmark\n",
"\n",
"results = run_benchmark(\n",
" index=index,\n",
Expand Down Expand Up @@ -892,7 +892,7 @@
},
"source": [
"# plot the results\n",
"from shapiq.games.benchmark.plot import plot_approximation_quality\n",
"from shapiq.benchmark import plot_approximation_quality\n",
"\n",
"# colors in the plot: \"KernelSHAPIQ\": orange, \"SVARMIQ\": blue, \"PermutationSamplingSII\": purple\n",
"plot_approximation_quality(results, log_scale_y=True)"
Expand Down
39 changes: 36 additions & 3 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
the well established Shapley value and its generalization to interaction.
"""

__version__ = "1.0.1.9000"
__version__ = "1.0.1.9001"

# approximator classes
from .approximator import (
Expand All @@ -21,18 +21,37 @@
UnbiasedKernelSHAP,
kADDSHAP,
)
from .benchmark import (
BENCHMARK_CONFIGURATIONS,
GAME_CLASS_TO_NAME_MAPPING,
GAME_NAME_TO_CLASS_MAPPING,
download_game_data,
load_benchmark_results,
load_game_data,
load_games_from_configuration,
plot_approximation_quality,
print_benchmark_configurations,
run_benchmark,
run_benchmark_from_configuration,
)

# dataset functions
from .datasets import load_adult_census, load_bike_sharing, load_california_housing

# exact computer classes
from .exact import ExactComputer

# explainer classes
from .explainer import Explainer, TabularExplainer, TreeExplainer
from .games import ConditionalImputer, Game, MarginalImputer

# game classes
from .games import ConditionalImputer, Game, MarginalImputer

# base classes
from .interaction_values import InteractionValues

# plotting functions
from .plot import force_plot, network_plot, stacked_bar_plot
from .plot import bar_plot, force_plot, network_plot, si_graph_plot, stacked_bar_plot

# public utils functions
from .utils import ( # sets.py # tree.py
Expand Down Expand Up @@ -75,6 +94,8 @@
"network_plot",
"stacked_bar_plot",
"force_plot",
"bar_plot",
"si_graph_plot",
# public utils
"powerset",
"get_explicit_subsets",
Expand All @@ -84,4 +105,16 @@
"load_bike_sharing",
"load_adult_census",
"load_california_housing",
# benchmark
"plot_approximation_quality",
"run_benchmark",
"run_benchmark_from_configuration",
"load_benchmark_results",
"print_benchmark_configurations",
"BENCHMARK_CONFIGURATIONS",
"GAME_CLASS_TO_NAME_MAPPING",
"GAME_NAME_TO_CLASS_MAPPING",
"load_games_from_configuration",
"download_game_data",
"load_game_data",
]
7 changes: 3 additions & 4 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
import numpy as np

from shapiq.approximator.sampling import CoalitionSampler
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import generate_interaction_lookup

from ..indices import (
from shapiq.indices import (
AVAILABLE_INDICES_FOR_APPROXIMATION,
get_computation_index,
is_empty_value_the_baseline,
is_index_aggregated,
)
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import generate_interaction_lookup

__all__ = [
"Approximator",
Expand Down
1 change: 1 addition & 0 deletions shapiq/benchmark/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
precomputed
31 changes: 31 additions & 0 deletions shapiq/benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""This module contains all functions for conducting benchmarks with the SHAPIQ package."""

from .configuration import (
BENCHMARK_CONFIGURATIONS,
GAME_CLASS_TO_NAME_MAPPING,
GAME_NAME_TO_CLASS_MAPPING,
print_benchmark_configurations,
)
from .load import download_game_data, load_game_data, load_games_from_configuration
from .plot import plot_approximation_quality
from .run import load_benchmark_results, run_benchmark, run_benchmark_from_configuration

__all__ = [
# # configuration
"print_benchmark_configurations",
"BENCHMARK_CONFIGURATIONS",
"GAME_CLASS_TO_NAME_MAPPING",
"GAME_NAME_TO_CLASS_MAPPING",
# # loading
"load_games_from_configuration",
"download_game_data",
"load_game_data",
# # running benchmarks
"run_benchmark_from_configuration",
"run_benchmark",
"load_benchmark_results",
# plotting benchmark results
"plot_approximation_quality",
]

# Path: shapiq/benchmark/__init__.py
Loading
Loading