From 5da63048038c5dfae8a69e8bec9bf744d8f19e26 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Thu, 2 Jan 2025 17:07:40 +0100 Subject: [PATCH 01/21] bar_plot without shap --- shapiq/plot/bar.py | 206 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 197 insertions(+), 9 deletions(-) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index af7f5163..4dfc1ea9 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -1,17 +1,212 @@ """Wrapper for the bar plot from the ``shap`` package.""" +import re from typing import Optional import matplotlib.pyplot as plt import numpy as np from ..interaction_values import InteractionValues -from ..utils.modules import check_import_module +from ._config import BLUE, RED from .utils import get_interaction_values_and_feature_names __all__ = ["bar_plot"] +def format_value(s, format_str): + """Strips trailing zeros and uses a unicode minus sign.""" + if not issubclass(type(s), str): + s = format_str % s + s = re.sub(r"\.?0+$", "", s) + if s[0] == "-": + s = "\u2212" + s[1:] + return s + + +def _bar(values, feature_names, max_display=10, ax=None, show=True): + """Create a bar plot of a set of SHAP values. + + Parameters + ---------- + shap_values : shap.Explanation or shap.Cohorts or dictionary of shap.Explanation objects + Passing a multi-row :class:`.Explanation` object creates a global + feature importance plot. + + Passing a single row of an explanation (i.e. ``shap_values[0]``) creates + a local feature importance plot. + + Passing a dictionary of Explanation objects will create a multiple-bar + plot with one bar type for each of the cohorts represented by the + explanation objects. + max_display : int + How many top features to include in the bar plot (default is 10). + order : OpChain or numpy.ndarray + A function that returns a sort ordering given a matrix of SHAP values + and an axis, or a direct sample ordering given as a ``numpy.ndarray``. + + By default, take the absolute value. + clustering: np.ndarray or None + A partition tree, as returned by ``shap.utils.hclust`` + clustering_cutoff: float + Controls how much of the clustering structure is displayed. + show_data: bool or str + Controls if data values are shown as part of the y tick labels. If + "auto", we show the data only when there are no transforms. + ax: matplotlib Axes + Axes object to draw the plot onto, otherwise uses the current Axes. + show : bool + Whether ``matplotlib.pyplot.show()`` is called before returning. + Setting this to ``False`` allows the plot + to be customized further after it has been created. + + Returns + ------- + ax: matplotlib Axes + Returns the Axes object with the plot drawn onto it. Only returned if ``show=False``. + + Examples + -------- + See `bar plot examples `_. + + """ + # assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!" + + # ensure we at least have default feature names + if feature_names is None: + feature_names = np.array([f"Feature {i}" for i in range(len(values[0]))]) + if issubclass(type(feature_names), str): + feature_names = [i + " " + feature_names for i in range(len(values[0]))] + + # build our auto xlabel based on the transform history of the Explanation object + xlabel = "SHAP value" + + # determine how many top features we will plot + if max_display is None: + max_display = len(feature_names) + num_features = min(max_display, len(values[0])) + max_display = min(max_display, num_features) + + # Make it descending order + feature_order = np.argsort(values)[0][::-1] + + y_pos = np.arange(len(feature_order), 0, -1) + + # build our y-tick labels + yticklabels = [] + for i in feature_order: + yticklabels.append(feature_names[i]) + + if ax is None: + ax = plt.gca() + # Only modify the figure size if ax was not passed in + # compute our figure size based on how many features we are showing + fig = plt.gcf() + row_height = 0.5 + fig.set_size_inches(8, num_features * row_height * np.sqrt(len(values)) + 1.5) + + # if negative values are present then we draw a vertical line to mark 0, otherwise the axis does this for us... + negative_values_present = np.sum(values[:, feature_order[:num_features]] < 0) > 0 + if negative_values_present: + ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1) + + # draw the bars + patterns = (None, "\\\\", "++", "xx", "////", "*", "o", "O", ".", "-") + total_width = 0.7 + bar_width = total_width / len(values) + for i in range(len(values)): + ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) + ax.barh( + y_pos + ypos_offset, + values[i, feature_order], + bar_width, + align="center", + color=[ + BLUE.hex if values[i, feature_order[j]] <= 0 else RED.hex for j in range(len(y_pos)) + ], + hatch=patterns[i], + edgecolor=(1, 1, 1, 0.8), + label="", + ) + + # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks) + ax.set_yticks( + list(y_pos) + list(y_pos + 1e-8), + yticklabels + [t.split("=")[-1] for t in yticklabels], + fontsize=13, + ) + + xlen = ax.get_xlim()[1] - ax.get_xlim()[0] + # xticks = ax.get_xticks() + bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted()) + width = bbox.width + bbox_to_xscale = xlen / width + + for i in range(len(values)): + ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) + for j in range(len(y_pos)): + ind = feature_order[j] + if values[i, ind] < 0: + ax.text( + values[i, ind] - (5 / 72) * bbox_to_xscale, + y_pos[j] + ypos_offset, + format_value(values[i, ind], "%+0.02f"), + horizontalalignment="right", + verticalalignment="center", + color=BLUE.hex, + fontsize=12, + ) + else: + ax.text( + values[i, ind] + (5 / 72) * bbox_to_xscale, + y_pos[j] + ypos_offset, + format_value(values[i, ind], "%+0.02f"), + horizontalalignment="left", + verticalalignment="center", + color=RED.hex, + fontsize=12, + ) + + # put horizontal lines for each feature row + for i in range(num_features): + ax.axhline(i + 1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1) + + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("none") + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + if negative_values_present: + ax.spines["left"].set_visible(False) + ax.tick_params("x", labelsize=11) + + xmin, xmax = ax.get_xlim() + ymin, ymax = ax.get_ylim() + x_buffer = (xmax - xmin) * 0.05 + + if negative_values_present: + ax.set_xlim(xmin - x_buffer, xmax + x_buffer) + else: + ax.set_xlim(xmin, xmax + x_buffer) + + # if features is None: + # pl.xlabel(labels["GLOBAL_VALUE"], fontsize=13) + # else: + ax.set_xlabel(xlabel, fontsize=13) + + if len(values) > 1: + ax.legend(fontsize=12) + + # color the y tick labels that have the feature values as gray + # (these fall behind the black ones with just the feature name) + tick_labels = ax.yaxis.get_majorticklabels() + for i in range(num_features): + tick_labels[i].set_color("#999999") + + if show: + plt.show() + else: + return ax + + def bar_plot( list_of_interaction_values: list[InteractionValues], feature_names: Optional[np.ndarray] = None, @@ -32,8 +227,6 @@ def bar_plot( abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. **kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``. """ - check_import_module("shap") - import shap assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1 @@ -53,13 +246,8 @@ def bar_plot( _base_values.append(iv.baseline_value) _labels = np.array(_labels) if feature_names is not None else None - explanation = shap.Explanation( - values=np.stack(_global_values), - base_values=np.array(_base_values), - feature_names=_labels, - ) - ax = shap.plots.bar(explanation, **kwargs, show=False) + ax = _bar(values=np.stack(_global_values), feature_names=_labels, show=False) ax.set_xlabel("mean(|Shapley Interaction value|)") if not show: return ax From 6de11758ee6d48ad5ac0cf0062f964cbde465cf5 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sat, 4 Jan 2025 13:18:15 +0100 Subject: [PATCH 02/21] bar_plot without shap --- shapiq/plot/bar.py | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index 4dfc1ea9..fc4d18b1 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -8,7 +8,6 @@ from ..interaction_values import InteractionValues from ._config import BLUE, RED -from .utils import get_interaction_values_and_feature_names __all__ = ["bar_plot"] @@ -77,8 +76,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): if issubclass(type(feature_names), str): feature_names = [i + " " + feature_names for i in range(len(values[0]))] - # build our auto xlabel based on the transform history of the Explanation object - xlabel = "SHAP value" + xlabel = "Shapley value" # determine how many top features we will plot if max_display is None: @@ -92,9 +90,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): y_pos = np.arange(len(feature_order), 0, -1) # build our y-tick labels - yticklabels = [] - for i in feature_order: - yticklabels.append(feature_names[i]) + yticklabels = [feature_names[i] for i in feature_order] if ax is None: ax = plt.gca() @@ -125,7 +121,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ], hatch=patterns[i], edgecolor=(1, 1, 1, 0.8), - label="", + label="Model " + str(i), ) # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks) @@ -193,7 +189,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ax.set_xlabel(xlabel, fontsize=13) if len(values) > 1: - ax.legend(fontsize=12) + ax.legend(fontsize=12, loc="lower right") # color the y tick labels that have the feature values as gray # (these fall behind the black ones with just the feature name) @@ -207,6 +203,15 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): return ax +def default_feature_name(feature_tuple): + if len(feature_tuple) == 0: + return "Basevalue" + elif len(feature_tuple) == 1: + return "Feature " + str(feature_tuple[0]) + else: + return " x ".join([str(f) for f in feature_tuple]) + + def bar_plot( list_of_interaction_values: list[InteractionValues], feature_names: Optional[np.ndarray] = None, @@ -230,25 +235,16 @@ def bar_plot( assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1 - _global_values = [] - _base_values = [] - _labels = [] - _first_iv = True - for iv in list_of_interaction_values: + values = np.stack([iv.values for iv in list_of_interaction_values]) - _shap_values, _names = get_interaction_values_and_feature_names( - iv, feature_names, None, abbreviate=abbreviate - ) - if _first_iv: - _labels = _names - _first_iv = False - _global_values.append(_shap_values) - _base_values.append(iv.baseline_value) - - _labels = np.array(_labels) if feature_names is not None else None + labels = ( + np.array(list(map(default_feature_name, list_of_interaction_values[0].dict_values.keys()))) + if feature_names is None + else feature_names + ) - ax = _bar(values=np.stack(_global_values), feature_names=_labels, show=False) - ax.set_xlabel("mean(|Shapley Interaction value|)") + ax = _bar(values=values, feature_names=labels, show=False) + ax.set_xlabel("Shapley value") if not show: return ax plt.show() From 360a4f13185a5596cb4d874e018d04f4e69750b0 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 11:08:26 +0100 Subject: [PATCH 03/21] forceplot without shap --- shapiq/plot/force.py | 587 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 551 insertions(+), 36 deletions(-) diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index 9907282a..de6e83c7 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -2,51 +2,566 @@ from typing import Optional +import matplotlib import matplotlib.pyplot as plt import numpy as np +from matplotlib import lines +from matplotlib.font_manager import FontProperties +from matplotlib.patches import PathPatch +from matplotlib.path import Path from ..interaction_values import InteractionValues -from ..utils.modules import check_import_module -from .utils import get_interaction_values_and_feature_names __all__ = ["force_plot"] -def force_plot( - interaction_values: InteractionValues, - feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, - matplotlib: bool = True, - show: bool = False, - abbreviate: bool = True, - **kwargs, -) -> Optional[plt.Figure]: - """Draws interaction values on a force plot. +def _create_bars( + out_value: float, + features: np.ndarray, + feature_type: str, + width_separators: float, + width_bar: float, +) -> tuple[list, list]: + """ + Create bars and separators for the plot. + Args: + out_value: the output value + features: names and values of the features to add + feature_type: Indicating whether positive or negative features + width_separators: width to separate the bars + width_bar: width of the bars + + Returns: List of bars and separators + """ + rectangle_list = [] + separator_list = [] + + pre_val = out_value + for index, features in zip(range(len(features)), features): + if feature_type == "positive": + left_bound = float(features[0]) + right_bound = pre_val + pre_val = left_bound + + separator_indent = np.abs(width_separators) + separator_pos = left_bound + colors = ["#FF0D57", "#FFC3D5"] + else: + left_bound = pre_val + right_bound = float(features[0]) + pre_val = right_bound + + separator_indent = -np.abs(width_separators) + separator_pos = right_bound + colors = ["#1E88E5", "#D1E6FA"] + + # Create rectangle + if index == 0: + if feature_type == "positive": + points_rectangle = [ + [left_bound, 0], + [right_bound, 0], + [right_bound, width_bar], + [left_bound, width_bar], + [left_bound + separator_indent, (width_bar / 2)], + ] + else: + points_rectangle = [ + [right_bound, 0], + [left_bound, 0], + [left_bound, width_bar], + [right_bound, width_bar], + [right_bound + separator_indent, (width_bar / 2)], + ] + + else: + points_rectangle = [ + [left_bound, 0], + [right_bound, 0], + [right_bound + separator_indent * 0.90, (width_bar / 2)], + [right_bound, width_bar], + [left_bound, width_bar], + [left_bound + separator_indent * 0.90, (width_bar / 2)], + ] + + line = plt.Polygon( + points_rectangle, closed=True, fill=True, facecolor=colors[0], linewidth=0 + ) + rectangle_list += [line] + + # Create separator + points_separator = [ + [separator_pos, 0], + [separator_pos + separator_indent, (width_bar / 2)], + [separator_pos, width_bar], + ] + + line = plt.Polygon(points_separator, closed=None, fill=None, edgecolor=colors[1], lw=3) + separator_list += [line] + + return rectangle_list, separator_list + + +def _add_labels( + fig: plt.Figure, + ax: plt.Axes, + out_value: float, + features: np.ndarray, + feature_type: str, + offset_text: float, + total_effect: float = 0, + min_perc: float = 0.05, + text_rotation: float = 0, +) -> None: + """ + Add labels to the plot. + Args: + fig: Figure of the plot + ax: Axes of the plot + out_value: output value + features: The values and names of the features + feature_type: Indicating whether positive or negative features + offset_text: value to offset name of the features + total_effect: Total value of all features. Used to filter out features that do not contribute at least min_perc to the total effect. + Defaults to 0 indicating that all features are shown. + min_perc: minimal percentage of the total effect that a feature must contribute to be shown. Defaults to 0.05. + text_rotation: Degree the text should be rotated. Defaults to 0. + + Returns: + + """ + start_text = out_value + pre_val = out_value + + # Define variables specific to positive and negative effect features + if feature_type == "positive": + colors = ["#FF0D57", "#FFC3D5"] + alignment = "right" + sign = 1 + else: + colors = ["#1E88E5", "#D1E6FA"] + alignment = "left" + sign = -1 + + # Draw initial line + if feature_type == "positive": + x, y = np.array([[pre_val, pre_val], [0, -0.18]]) + line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0]) + line.set_clip_on(False) + ax.add_line(line) + start_text = pre_val + + box_end = out_value + val = out_value + for feature in features: + # Exclude all labels that do not contribute at least 10% to the total + feature_contribution = np.abs(float(feature[0]) - pre_val) / np.abs(total_effect) + if feature_contribution < min_perc: + break + + # Compute value for current feature + val = float(feature[0]) + + # Draw labels. + text = feature[1] + + if text_rotation != 0: + va_alignment = "top" + else: + va_alignment = "baseline" + + text_out_val = plt.text( + start_text - sign * offset_text, + -0.15, + text, + fontsize=12, + color=colors[0], + horizontalalignment=alignment, + va=va_alignment, + rotation=text_rotation, + ) + text_out_val.set_bbox(dict(facecolor="none", edgecolor="none")) + + # We need to draw the plot to be able to get the size of the + # text box + fig.canvas.draw() + box_size = text_out_val.get_bbox_patch().get_extents().transformed(ax.transData.inverted()) + if feature_type == "positive": + box_end_ = box_size.get_points()[0][0] + else: + box_end_ = box_size.get_points()[1][0] + + # Create end line + if (sign * box_end_) > (sign * val): + x, y = np.array([[val, val], [0, -0.18]]) + line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0]) + line.set_clip_on(False) + ax.add_line(line) + start_text = val + box_end = val + + else: + box_end = box_end_ - sign * offset_text + x, y = np.array([[val, box_end, box_end], [0, -0.08, -0.18]]) + line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0]) + line.set_clip_on(False) + ax.add_line(line) + start_text = box_end + + # Update previous value + pre_val = float(feature[0]) + + # Create line for labels + extent_shading = [out_value, box_end, 0, -0.31] + path = [ + [out_value, 0], + [pre_val, 0], + [box_end, -0.08], + [box_end, -0.2], + [out_value, -0.2], + [out_value, 0], + ] + + path = Path(path) + patch = PathPatch(path, facecolor="none", edgecolor="none") + ax.add_patch(patch) + + # Extend axis if needed + lower_lim, upper_lim = ax.get_xlim() + if box_end < lower_lim: + ax.set_xlim(box_end, upper_lim) + + if box_end > upper_lim: + ax.set_xlim(lower_lim, box_end) + + # Create shading + if feature_type == "positive": + colors = np.array([(255, 13, 87), (255, 255, 255)]) / 255.0 + else: + colors = np.array([(30, 136, 229), (255, 255, 255)]) / 255.0 + + cm = matplotlib.colors.LinearSegmentedColormap.from_list("cm", colors) + + _, Z2 = np.meshgrid(np.linspace(0, 10), np.linspace(-10, 10)) + im = plt.imshow( + Z2, + interpolation="quadric", + cmap=cm, + vmax=0.01, + alpha=0.3, + origin="lower", + extent=extent_shading, + clip_path=patch, + clip_on=True, + aspect="auto", + ) + im.set_clip_path(patch) + + return fig, ax + + +def _add_output_element(out_name: str, out_value: float, ax: plt.Axes) -> None: + """ + Add grew line indicating the output value to the plot. + Args: + out_name: Name of the output value + out_value: Value of the output + ax: Axis of the plot + + Returns: Nothing + + """ + # Add output value + x, y = np.array([[out_value, out_value], [0, 0.24]]) + line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2") + line.set_clip_on(False) + ax.add_line(line) + + font0 = FontProperties() + font = font0.copy() + font.set_weight("bold") + text_out_val = plt.text( + out_value, + 0.25, + f"{out_value:.2f}", + fontproperties=font, + fontsize=14, + horizontalalignment="center", + ) + text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) + + text_out_val = plt.text( + out_value, 0.33, out_name, fontsize=12, alpha=0.5, horizontalalignment="center" + ) + text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) + + +def _add_base_value(base_value: float, ax: plt.Axes) -> None: + """ + Add base value to the plot. + Args: + base_value: the base value of the game + ax: Axes of the plot + + Returns: None + + """ + x, y = np.array([[base_value, base_value], [0.13, 0.25]]) + line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2") + line.set_clip_on(False) + ax.add_line(line) + + text_out_val = ax.text( + base_value, 0.33, "base value", fontsize=12, alpha=1, horizontalalignment="center" + ) + text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) + + +def draw_higher_lower_element(out_value, offset_text): + plt.text( + out_value - offset_text, + 0.405, + "higher", + fontsize=13, + color="#FF0D57", + horizontalalignment="right", + ) + + plt.text( + out_value + offset_text, + 0.405, + "lower", + fontsize=13, + color="#1E88E5", + horizontalalignment="left", + ) + + plt.text( + out_value, 0.4, r"$\leftarrow$", fontsize=13, color="#1E88E5", horizontalalignment="center" + ) + + plt.text( + out_value, + 0.425, + r"$\rightarrow$", + fontsize=13, + color="#FF0D57", + horizontalalignment="center", + ) - Requires the ``shap`` Python package to be installed. +def update_axis_limits( + ax: plt.Axes, + total_pos: float, + pos_features: np.ndarray, + total_neg: float, + neg_features: np.ndarray, + base_value: float, + out_value: float, +) -> None: + """ + Adjust the axis limits of the plot according to values. + Args: + ax: Axes of the plot + total_pos: value of the total positive features + pos_features: values and names of the positive features + total_neg: value of the total negative features + neg_features: values and names of the negative features + base_value: the base value of the game + out_value: the output value + + Returns: None + + """ + ax.set_ylim(-0.5, 0.15) + padding = np.max([np.abs(total_pos) * 0.2, np.abs(total_neg) * 0.2]) + + if len(pos_features) > 0: + min_x = min(np.min(pos_features[:, 0].astype(float)), base_value) - padding + else: + min_x = out_value - padding + if len(neg_features) > 0: + max_x = max(np.max(neg_features[:, 0].astype(float)), base_value) + padding + else: + max_x = out_value + padding + ax.set_xlim(min_x, max_x) + + plt.tick_params( + top=True, + bottom=False, + left=False, + right=False, + labelleft=False, + labeltop=True, + labelbottom=False, + ) + plt.locator_params(axis="x", nbins=12) + + for key, spine in zip(plt.gca().spines.keys(), plt.gca().spines.values()): + if key != "top": + spine.set_visible(False) + + +def _draw_force_plot( + interaction_value: InteractionValues, + feature_names: np.ndarray, + figsize: tuple[int, int], + show: bool = True, + text_rotation: float = 0, + min_perc: float = 0.05, +): + """ + Draw the force plot. Args: - interaction_values: The interaction values as an interaction object. - feature_names: The feature names used for plotting. If no feature names are provided, the - feature indices are used instead. Defaults to ``None``. - feature_values: The feature values used for plotting. Defaults to ``None``. - matplotlib: Whether to return a ``matplotlib`` figure. Defaults to ``True``. - show: Whether to show the plot. Defaults to ``False``. - abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. - **kwargs: Keyword arguments passed to ``shap.plots.force()``. - """ - check_import_module("shap") - import shap - - _shap_values, _labels = get_interaction_values_and_feature_names( - interaction_values, feature_names, feature_values, abbreviate=abbreviate - ) - - return shap.plots.force( - base_value=np.array([interaction_values.baseline_value], dtype=float), # must be array - shap_values=np.array(_shap_values), - feature_names=_labels, - matplotlib=matplotlib, - show=show, - **kwargs, + interaction_value: Interactiovalues ot be plotted + feature_names: names of the features + figsize: size of the figure + show: Whether to show the plot + text_rotation: Amount of text rotation + min_perc: Define the minimum percentage of the total effect that a feature must contribute to be shown. + Defaults to 0.05. + + Returns: None + + """ + # Turn off interactive plot + if show is False: + plt.ioff() + + # Compute overall metrics + base_value = interaction_value.baseline_value + out_value = np.sum(interaction_value.values) # TODO: Must be the value of the grand coalition + # Format data + feature_to_names = {i: name for i, name in enumerate(feature_names)} + dict_values = interaction_value.dict_values + pos_features = np.array( + sorted( + [ + [str(values), " x ".join([feature_to_names[f] for f in features])] + for features, values in dict_values.items() + if values >= 0 and len(features) > 0 + ], + key=lambda x: x[0], + reverse=True, + ), + dtype=object, + ) + neg_features = np.array( + sorted( + [ + [str(values), " x ".join([feature_to_names[f] for f in features])] + for features, values in dict_values.items() + if values < 0 and len(features) > 0 + ], + key=lambda x: x[0], + reverse=True, + ), + dtype=object, + ) + + # Convert negative feature values to plot values + neg_val = out_value + for i in neg_features: + val = float(i[0]) + neg_val = neg_val + np.abs(val) + i[0] = neg_val + if len(neg_features) > 0: + total_neg = np.max(neg_features[:, 0].astype(float)) - np.min( + neg_features[:, 0].astype(float) + ) + else: + total_neg = 0 + + # Convert positive feature values to plot values + pos_val = out_value + for i in pos_features: + val = float(i[0]) + pos_val = pos_val - np.abs(val) + i[0] = pos_val + + if len(pos_features) > 0: + total_pos = np.max(pos_features[:, 0].astype(float)) - np.min( + pos_features[:, 0].astype(float) + ) + else: + total_pos = 0 + + # Define plots + offset_text = (np.abs(total_neg) + np.abs(total_pos)) * 0.04 + + fig, ax = plt.subplots(figsize=figsize) + + # Compute axis limit + update_axis_limits(ax, total_pos, pos_features, total_neg, neg_features, base_value, out_value) + + # Define width of bar + width_bar = 0.1 + width_separators = (ax.get_xlim()[1] - ax.get_xlim()[0]) / 200 + + # Create bar for negative shap values + rectangle_list, separator_list = _create_bars( + out_value, neg_features, "negative", width_separators, width_bar + ) + for i in rectangle_list: + ax.add_patch(i) + + for i in separator_list: + ax.add_patch(i) + + # Create bar for positive shap values + rectangle_list, separator_list = _create_bars( + out_value, pos_features, "positive", width_separators, width_bar ) + for i in rectangle_list: + ax.add_patch(i) + + for i in separator_list: + ax.add_patch(i) + + # Add labels + total_effect = np.abs(total_neg) + total_pos + fig, ax = _add_labels( + fig, + ax, + out_value, + neg_features, + "negative", + offset_text, + total_effect, + min_perc=min_perc, + text_rotation=text_rotation, + ) + + fig, ax = _add_labels( + fig, + ax, + out_value, + pos_features, + "positive", + offset_text, + total_effect, + min_perc=min_perc, + text_rotation=text_rotation, + ) + + # Add label for base value + _add_base_value(base_value, ax) + + # Add output label + out_names = "" + _add_output_element(out_names, out_value, ax) + + if show: + plt.show() + else: + return plt.gcf() + + +def force_plot( + interaction_values: InteractionValues, + feature_names: Optional[np.ndarray] = None, + show: bool = False, +): + if feature_names is None: + feature_names = np.array([str(i) for i in range(interaction_values.n_players)]) + return _draw_force_plot(interaction_values, feature_names, figsize=(20, 3), show=show) From 40e5334648e624e6e5a668584bae08c508cd8195 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 11:38:00 +0100 Subject: [PATCH 04/21] Made waterfall plot without shap. --- shapiq/plot/watefall.py | 358 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 326 insertions(+), 32 deletions(-) diff --git a/shapiq/plot/watefall.py b/shapiq/plot/watefall.py index 0af4b941..704baf94 100644 --- a/shapiq/plot/watefall.py +++ b/shapiq/plot/watefall.py @@ -2,58 +2,352 @@ from typing import Optional +import matplotlib import matplotlib.pyplot as plt import numpy as np from ..interaction_values import InteractionValues -from ..utils.modules import check_import_module -from .utils import get_interaction_values_and_feature_names +from ._config import BLUE, RED +from .utils import format_value __all__ = ["waterfall_plot"] +def _draw_waterfall_plot( + values: np.ndarray, base_values: float, feature_names: list[str], max_display=10, show=True +) -> Optional[plt.Axes]: + """ + Create a waterfall plot idential to SHAP waterfall plot (https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/waterfall.html). + Args: + values: the explanation values + base_values: the base value of the game + feature_names: the names of the features + max_display: the maximum number of features to display + show: whether to show the plot + + Returns: the plot if show is False + + """ + # Turn off interactive plot + if show is False: + plt.ioff() + + # init variables we use for tracking the plot locations + num_features = min(max_display, len(values)) + row_height = 0.5 + rng = range(num_features - 1, -1, -1) + order = np.argsort(-np.abs(values)) + pos_lefts = [] + pos_inds = [] + pos_widths = [] + pos_low = [] + pos_high = [] + neg_lefts = [] + neg_inds = [] + neg_widths = [] + neg_low = [] + neg_high = [] + loc = base_values + values.sum() + yticklabels = ["" for _ in range(num_features + 1)] + + # size the plot based on how many features we are plotting + plt.gcf().set_size_inches(8, num_features * row_height + 1.5) + + # see how many individual (vs. grouped at the end) features we are plotting + if num_features == len(values): + num_individual = num_features + else: + num_individual = num_features - 1 + + # compute the locations of the individual features and plot the dashed connecting lines + for i in range(num_individual): + sval = values[order[i]] + loc -= sval + if sval >= 0: + pos_inds.append(rng[i]) + pos_widths.append(sval) + pos_lefts.append(loc) + else: + neg_inds.append(rng[i]) + neg_widths.append(sval) + neg_lefts.append(loc) + if num_individual != num_features or i + 4 < num_individual: + plt.plot( + [loc, loc], + [rng[i] - 1 - 0.4, rng[i] + 0.4], + color="#bbbbbb", + linestyle="--", + linewidth=0.5, + zorder=-1, + ) + yticklabels[rng[i]] = feature_names[order[i]] + + # add a last grouped feature to represent the impact of all the features we didn't show + if num_features < len(values): + yticklabels[0] = "%d other features" % (len(values) - num_features + 1) + remaining_impact = base_values - loc + if remaining_impact < 0: + pos_inds.append(0) + pos_widths.append(-remaining_impact) + pos_lefts.append(loc + remaining_impact) + else: + neg_inds.append(0) + neg_widths.append(-remaining_impact) + neg_lefts.append(loc + remaining_impact) + + points = ( + pos_lefts + + list(np.array(pos_lefts) + np.array(pos_widths)) + + neg_lefts + + list(np.array(neg_lefts) + np.array(neg_widths)) + ) + dataw = np.max(points) - np.min(points) + + # draw invisible bars just for sizing the axes + label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths]) + plt.barh( + pos_inds, + np.array(pos_widths) + label_padding + 0.02 * dataw, + left=np.array(pos_lefts) - 0.01 * dataw, + color=RED.hex, + alpha=0, + ) + label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths]) + plt.barh( + neg_inds, + np.array(neg_widths) + label_padding - 0.02 * dataw, + left=np.array(neg_lefts) + 0.01 * dataw, + color=BLUE.hex, + alpha=0, + ) + + # define variable we need for plotting the arrows + head_length = 0.08 + bar_width = 0.8 + xlen = plt.xlim()[1] - plt.xlim()[0] + fig = plt.gcf() + ax = plt.gca() + bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + width = bbox.width + bbox_to_xscale = xlen / width + hl_scaled = bbox_to_xscale * head_length + renderer = fig.canvas.get_renderer() + + # draw the positive arrows + for i in range(len(pos_inds)): + dist = pos_widths[i] + arrow_obj = plt.arrow( + pos_lefts[i], + pos_inds[i], + max(dist - hl_scaled, 0.000001), + 0, + head_length=min(dist, hl_scaled), + color=RED.hex, + width=bar_width, + head_width=bar_width, + ) + + if pos_low is not None and i < len(pos_low): + plt.errorbar( + pos_lefts[i] + pos_widths[i], + pos_inds[i], + xerr=np.array([[pos_widths[i] - pos_low[i]], [pos_high[i] - pos_widths[i]]]), + ecolor=BLUE.hex, + ) + + txt_obj = plt.text( + pos_lefts[i] + 0.5 * dist, + pos_inds[i], + format_value(pos_widths[i], "%+0.02f"), + horizontalalignment="center", + verticalalignment="center", + color="white", + fontsize=12, + ) + text_bbox = txt_obj.get_window_extent(renderer=renderer) + arrow_bbox = arrow_obj.get_window_extent(renderer=renderer) + + # if the text overflows the arrow then draw it after the arrow + if text_bbox.width > arrow_bbox.width: + txt_obj.remove() + + txt_obj = plt.text( + pos_lefts[i] + (5 / 72) * bbox_to_xscale + dist, + pos_inds[i], + format_value(pos_widths[i], "%+0.02f"), + horizontalalignment="left", + verticalalignment="center", + color=RED.hex, + fontsize=12, + ) + + # draw the negative arrows + for i in range(len(neg_inds)): + dist = neg_widths[i] + + arrow_obj = plt.arrow( + neg_lefts[i], + neg_inds[i], + -max(-dist - hl_scaled, 0.000001), + 0, + head_length=min(-dist, hl_scaled), + color=BLUE.hex, + width=bar_width, + head_width=bar_width, + ) + + if neg_low is not None and i < len(neg_low): + plt.errorbar( + neg_lefts[i] + neg_widths[i], + neg_inds[i], + xerr=np.array([[neg_widths[i] - neg_low[i]], [neg_high[i] - neg_widths[i]]]), + ecolor=RED.hex, + ) + + txt_obj = plt.text( + neg_lefts[i] + 0.5 * dist, + neg_inds[i], + format_value(neg_widths[i], "%+0.02f"), + horizontalalignment="center", + verticalalignment="center", + color="white", + fontsize=12, + ) + text_bbox = txt_obj.get_window_extent(renderer=renderer) + arrow_bbox = arrow_obj.get_window_extent(renderer=renderer) + + # if the text overflows the arrow then draw it after the arrow + if text_bbox.width > arrow_bbox.width: + txt_obj.remove() + + txt_obj = plt.text( + neg_lefts[i] - (5 / 72) * bbox_to_xscale + dist, + neg_inds[i], + format_value(neg_widths[i], "%+0.02f"), + horizontalalignment="right", + verticalalignment="center", + color=BLUE.hex, + fontsize=12, + ) + + # draw the y-ticks twice, once in gray and then again with just the feature names in black + # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ytick_pos = list(range(num_features)) + list(np.arange(num_features) + 1e-8) + plt.yticks( + ytick_pos, + yticklabels[:-1] + [label.split("=")[-1] for label in yticklabels[:-1]], + fontsize=13, + ) + + # put horizontal lines for each feature row + for i in range(num_features): + plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1) + + # mark the prior expected value and the model prediction + plt.axvline( + base_values, 0, 1 / num_features, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1 + ) + fx = base_values + values.sum() + plt.axvline(fx, 0, 1, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1) + + # clean up the main axis + plt.gca().xaxis.set_ticks_position("bottom") + plt.gca().yaxis.set_ticks_position("none") + plt.gca().spines["right"].set_visible(False) + plt.gca().spines["top"].set_visible(False) + plt.gca().spines["left"].set_visible(False) + ax.tick_params(labelsize=13) + # plt.xlabel("\nModel output", fontsize=12) + + # draw the E[f(X)] tick mark + xmin, xmax = ax.get_xlim() + ax2 = ax.twiny() + ax2.set_xlim(xmin, xmax) + ax2.set_xticks( + [base_values, base_values + 1e-8] + ) # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ax2.set_xticklabels( + ["\n$E[f(X)]$", "\n$ = " + format_value(base_values, "%0.03f") + "$"], + fontsize=12, + ha="left", + ) + ax2.spines["right"].set_visible(False) + ax2.spines["top"].set_visible(False) + ax2.spines["left"].set_visible(False) + + # draw the f(x) tick mark + ax3 = ax2.twiny() + ax3.set_xlim(xmin, xmax) + # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ax3.set_xticks([base_values + values.sum(), base_values + values.sum() + 1e-8]) + ax3.set_xticklabels( + ["$f(x)$", "$ = " + format_value(fx, "%0.03f") + "$"], fontsize=12, ha="left" + ) + tick_labels = ax3.xaxis.get_majorticklabels() + tick_labels[0].set_transform( + tick_labels[0].get_transform() + + matplotlib.transforms.ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_transform( + tick_labels[1].get_transform() + + matplotlib.transforms.ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_color("#999999") + ax3.spines["right"].set_visible(False) + ax3.spines["top"].set_visible(False) + ax3.spines["left"].set_visible(False) + + # adjust the position of the E[f(X)] = x.xx label + tick_labels = ax2.xaxis.get_majorticklabels() + tick_labels[0].set_transform( + tick_labels[0].get_transform() + + matplotlib.transforms.ScaledTranslation(-20 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_transform( + tick_labels[1].get_transform() + + matplotlib.transforms.ScaledTranslation(22 / 72.0, -1 / 72.0, fig.dpi_scale_trans) + ) + + tick_labels[1].set_color("#999999") + + # color the y tick labels that have the feature values as gray + # (these fall behind the black ones with just the feature name) + tick_labels = ax.yaxis.get_majorticklabels() + for i in range(num_features): + tick_labels[i].set_color("#999999") + + if show: + plt.show() + else: + return plt.gca() + + def waterfall_plot( interaction_values: InteractionValues, - feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, show: bool = False, - abbreviate: bool = True, max_display: int = 10, ) -> Optional[plt.Axes]: - """Draws interaction values on a waterfall plot. + """Draws a waterfall plot with the interaction values. Note: Requires the ``shap`` Python package to be installed. Args: interaction_values: The interaction values as an interaction object. - feature_names: The feature names used for plotting. If no feature names are provided, the - feature indices are used instead. Defaults to ``None``. - feature_values: The feature values used for plotting. Defaults to ``None``. show: Whether to show the plot. Defaults to ``False``. - abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``. max_display: The maximum number of interactions to display. Defaults to ``10``. """ - check_import_module("shap") - import shap - - if interaction_values.max_order == 1: - shap_explanation = shap.Explanation( - values=interaction_values.get_n_order_values(1), - base_values=interaction_values.baseline_value, - data=feature_values, - feature_names=feature_names, - ) - else: - _shap_values, _labels = get_interaction_values_and_feature_names( - interaction_values, feature_names, feature_values, abbreviate=abbreviate - ) - - shap_explanation = shap.Explanation( - values=np.array(_shap_values), - base_values=np.array([interaction_values.baseline_value], dtype=float), - data=None, - feature_names=_labels, - ) + data = np.array( + [ + (" x ".join([str(f) for f in feature_tuple]), str(value)) + for feature_tuple, value in interaction_values.dict_values.items() + if len(feature_tuple) > 0 + ], + dtype=object, + ) + values = data[:, 1].astype(float) + feature_names = data[:, 0] - return shap.plots.waterfall(shap_explanation, max_display=max_display, show=show) + return _draw_waterfall_plot( + values, interaction_values.baseline_value, feature_names, max_display=max_display, show=show + ) From b5eea32a6c13b69a2e20d1729c117154335d1667 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 11:39:40 +0100 Subject: [PATCH 05/21] Introduces formal_value function from shap to have identical looking plots. Small Refactor of force --- shapiq/plot/bar.py | 12 +----------- shapiq/plot/force.py | 2 +- shapiq/plot/utils.py | 13 ++++++++++++- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index fc4d18b1..5f4ecfa5 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -1,6 +1,5 @@ """Wrapper for the bar plot from the ``shap`` package.""" -import re from typing import Optional import matplotlib.pyplot as plt @@ -8,20 +7,11 @@ from ..interaction_values import InteractionValues from ._config import BLUE, RED +from .utils import format_value __all__ = ["bar_plot"] -def format_value(s, format_str): - """Strips trailing zeros and uses a unicode minus sign.""" - if not issubclass(type(s), str): - s = format_str % s - s = re.sub(r"\.?0+$", "", s) - if s[0] == "-": - s = "\u2212" + s[1:] - return s - - def _bar(values, feature_names, max_display=10, ax=None, show=True): """Create a bar plot of a set of SHAP values. diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index de6e83c7..e908965a 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -430,7 +430,7 @@ def _draw_force_plot( # Compute overall metrics base_value = interaction_value.baseline_value - out_value = np.sum(interaction_value.values) # TODO: Must be the value of the grand coalition + out_value = np.sum(interaction_value.values) # Sum of all values with the baseline value # Format data feature_to_names = {i: name for i, name in enumerate(feature_names)} dict_values = interaction_value.dict_values diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index d5c6a117..629209eb 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -1,6 +1,7 @@ """This utility module contains helper functions for plotting.""" import copy +import re from collections.abc import Iterable from typing import Optional @@ -9,7 +10,17 @@ from ..interaction_values import InteractionValues from ..utils import powerset -__all__ = ["get_interaction_values_and_feature_names", "abbreviate_feature_names"] +__all__ = ["get_interaction_values_and_feature_names", "abbreviate_feature_names", "format_value"] + + +def format_value(s, format_str): + """Strips trailing zeros and uses a unicode minus sign.""" + if not issubclass(type(s), str): + s = format_str % s + s = re.sub(r"\.?0+$", "", s) + if s[0] == "-": + s = "\u2212" + s[1:] + return s def get_interaction_values_and_feature_names( From f7f96f579cb25accceffa29d17318e5bd63cc3f1 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 12:23:41 +0100 Subject: [PATCH 06/21] Removed errors in plot functions. Re-added some features used in interaction_values.py for the waterfall,force and bar plot. --- shapiq/interaction_values.py | 12 +---------- shapiq/plot/bar.py | 31 ++++++++++++++++++++--------- shapiq/plot/force.py | 18 ++++++++++++++++- shapiq/plot/watefall.py | 31 +++++++++++++++++++++++------ tests/tests_plots/test_force.py | 7 +------ tests/tests_plots/test_waterfall.py | 7 +------ 6 files changed, 67 insertions(+), 39 deletions(-) diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index ae98d8fa..83b5d782 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -682,18 +682,13 @@ def plot_stacked_bar( def plot_force( self, feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, - matplotlib=True, show: bool = True, abbreviate: bool = True, - **kwargs, ) -> Optional[plt.Figure]: """Visualize InteractionValues on a force plot. For arguments, see shapiq.plots.force_plot(). - Requires the ``shap`` Python package to be installed. - Args: feature_names: The feature names used for plotting. If no feature names are provided, the feature indices are used instead. Defaults to ``None``. @@ -710,18 +705,14 @@ def plot_force( return force_plot( self, - feature_values=feature_values, feature_names=feature_names, - matplotlib=matplotlib, show=show, abbreviate=abbreviate, - **kwargs, ) def plot_waterfall( self, feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, show: bool = True, abbreviate: bool = True, max_display: int = 10, @@ -743,11 +734,10 @@ def plot_waterfall( return waterfall_plot( self, - feature_values=feature_values, feature_names=feature_names, show=show, - abbreviate=abbreviate, max_display=max_display, + abbreviate=abbreviate, ) def plot_sentence( diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index 5f4ecfa5..1eeb841d 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -75,7 +75,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): max_display = min(max_display, num_features) # Make it descending order - feature_order = np.argsort(values)[0][::-1] + feature_order = np.argsort(np.mean(values, axis=0))[::-1] y_pos = np.arange(len(feature_order), 0, -1) @@ -88,7 +88,10 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): # compute our figure size based on how many features we are showing fig = plt.gcf() row_height = 0.5 - fig.set_size_inches(8, num_features * row_height * np.sqrt(len(values)) + 1.5) + fig.set_size_inches( + 8 + 0.3 * max([len(f) for f in feature_names]), + num_features * row_height * np.sqrt(len(values)) + 1.5, + ) # if negative values are present then we draw a vertical line to mark 0, otherwise the axis does this for us... negative_values_present = np.sum(values[:, feature_order[:num_features]] < 0) > 0 @@ -193,13 +196,13 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): return ax -def default_feature_name(feature_tuple): +def format_labels(feature_mapping, feature_tuple): if len(feature_tuple) == 0: return "Basevalue" elif len(feature_tuple) == 1: - return "Feature " + str(feature_tuple[0]) + return str(feature_mapping[feature_tuple[0]]) else: - return " x ".join([str(f) for f in feature_tuple]) + return " x ".join([feature_mapping[f] for f in feature_tuple]) def bar_plot( @@ -223,14 +226,24 @@ def bar_plot( **kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``. """ + n_players = list_of_interaction_values[0].n_players + + if feature_names is not None: + feature_mapping = {i: feature_names[i] for i in range(n_players)} + else: + feature_mapping = {i: str(i) for i in range(n_players)} + assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1 values = np.stack([iv.values for iv in list_of_interaction_values]) - labels = ( - np.array(list(map(default_feature_name, list_of_interaction_values[0].dict_values.keys()))) - if feature_names is None - else feature_names + labels = np.array( + list( + map( + lambda x: format_labels(feature_mapping, x), + list_of_interaction_values[0].dict_values.keys(), + ) + ) ) ax = _bar(values=values, feature_names=labels, show=False) diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index e908965a..1f760546 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -11,6 +11,7 @@ from matplotlib.path import Path from ..interaction_values import InteractionValues +from .utils import abbreviate_feature_names __all__ = ["force_plot"] @@ -561,7 +562,22 @@ def force_plot( interaction_values: InteractionValues, feature_names: Optional[np.ndarray] = None, show: bool = False, -): + abbreviate: bool = True, +) -> Optional[plt.Figure]: + """ + Draw a force plot. + Args: + interaction_values: + feature_names: + show: + abbreviate: + + Returns: + + """ if feature_names is None: feature_names = np.array([str(i) for i in range(interaction_values.n_players)]) + if abbreviate: + feature_names = abbreviate_feature_names(feature_names) + return _draw_force_plot(interaction_values, feature_names, figsize=(20, 3), show=show) diff --git a/shapiq/plot/watefall.py b/shapiq/plot/watefall.py index 704baf94..21b18960 100644 --- a/shapiq/plot/watefall.py +++ b/shapiq/plot/watefall.py @@ -8,7 +8,7 @@ from ..interaction_values import InteractionValues from ._config import BLUE, RED -from .utils import format_value +from .utils import abbreviate_feature_names, format_value __all__ = ["waterfall_plot"] @@ -84,7 +84,7 @@ def _draw_waterfall_plot( # add a last grouped feature to represent the impact of all the features we didn't show if num_features < len(values): - yticklabels[0] = "%d other features" % (len(values) - num_features + 1) + yticklabels[0] = "%d other features".format() remaining_impact = base_values - loc if remaining_impact < 0: pos_inds.append(0) @@ -322,29 +322,48 @@ def _draw_waterfall_plot( return plt.gca() +def format_labels(feature_mapping, feature_tuple): + if len(feature_tuple) == 0: + return "Basevalue" + elif len(feature_tuple) == 1: + return str(feature_mapping[feature_tuple[0]]) + else: + return " x ".join([feature_mapping[f] for f in feature_tuple]) + + def waterfall_plot( interaction_values: InteractionValues, + feature_names: Optional[np.ndarray[str]] = None, show: bool = False, max_display: int = 10, + abbreviate: bool = True, ) -> Optional[plt.Axes]: """Draws a waterfall plot with the interaction values. - Note: - Requires the ``shap`` Python package to be installed. - Args: interaction_values: The interaction values as an interaction object. + feature_names: The names of the features. Defaults to ``None``. show: Whether to show the plot. Defaults to ``False``. max_display: The maximum number of interactions to display. Defaults to ``10``. + abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. """ + + if feature_names is None: + feature_mapping = {i: str(i) for i in range(interaction_values.n_players)} + else: + if abbreviate: + feature_names = abbreviate_feature_names(feature_names) + feature_mapping = {i: feature_names[i] for i in range(interaction_values.n_players)} + data = np.array( [ - (" x ".join([str(f) for f in feature_tuple]), str(value)) + (format_labels(feature_mapping, feature_tuple), str(value)) for feature_tuple, value in interaction_values.dict_values.items() if len(feature_tuple) > 0 ], dtype=object, ) + values = data[:, 1].astype(float) feature_names = data[:, 0] diff --git a/tests/tests_plots/test_force.py b/tests/tests_plots/test_force.py index 3803437e..b17290d6 100644 --- a/tests/tests_plots/test_force.py +++ b/tests/tests_plots/test_force.py @@ -13,7 +13,6 @@ def test_force_plot(interaction_values_list: list[InteractionValues]): n_players = iv.n_players feature_names = [f"feature-{i}" for i in range(n_players)] feature_names = np.array(feature_names) - feature_values = np.array([i for i in range(n_players)]) fp = force_plot(iv, show=False) assert fp is not None @@ -25,11 +24,7 @@ def test_force_plot(interaction_values_list: list[InteractionValues]): assert isinstance(fp, plt.Figure) plt.close() - fp = force_plot(iv, show=False, feature_names=feature_names, feature_values=feature_values) - assert isinstance(fp, plt.Figure) - plt.close() - - fp = force_plot(iv, show=False, feature_names=None, feature_values=feature_values) + fp = force_plot(iv, show=False, feature_names=feature_names) assert isinstance(fp, plt.Figure) plt.close() diff --git a/tests/tests_plots/test_waterfall.py b/tests/tests_plots/test_waterfall.py index bedc047e..55f2d961 100644 --- a/tests/tests_plots/test_waterfall.py +++ b/tests/tests_plots/test_waterfall.py @@ -13,18 +13,13 @@ def test_waterfall_plot(interaction_values_list: list[InteractionValues]): n_players = iv.n_players feature_names = [f"feature-{i}" for i in range(n_players)] feature_names = np.array(feature_names) - feature_values = np.array([i for i in range(n_players)]) wp = waterfall_plot(iv, show=False) assert wp is not None assert isinstance(wp, plt.Axes) plt.close() - wp = waterfall_plot(iv, show=False, feature_names=feature_names, feature_values=feature_values) - assert isinstance(wp, plt.Axes) - plt.close() - - wp = waterfall_plot(iv, show=False, feature_names=None, feature_values=feature_values) + wp = waterfall_plot(iv, show=False, feature_names=feature_names) assert isinstance(wp, plt.Axes) plt.close() From 7283bbc41ae561eb4360a68979ef952d1ba59bda Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 12:45:17 +0100 Subject: [PATCH 07/21] Refactor force. --- shapiq/plot/force.py | 118 +++++++++++++++++++++++++++---------------- 1 file changed, 75 insertions(+), 43 deletions(-) diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index 1f760546..33f8f6cc 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -11,7 +11,7 @@ from matplotlib.path import Path from ..interaction_values import InteractionValues -from .utils import abbreviate_feature_names +from .utils import abbreviate_feature_names, format_labels __all__ = ["force_plot"] @@ -403,44 +403,22 @@ def update_axis_limits( spine.set_visible(False) -def _draw_force_plot( - interaction_value: InteractionValues, - feature_names: np.ndarray, - figsize: tuple[int, int], - show: bool = True, - text_rotation: float = 0, - min_perc: float = 0.05, -): +def _split_features( + interaction_dictionary: dict[tuple[int, ...], float], + feature_to_names: dict[int, str], + out_value: float, +) -> tuple[np.ndarray, np.ndarray, float, float]: """ - Draw the force plot. + Split the features into positive and negative values. Args: - interaction_value: Interactiovalues ot be plotted - feature_names: names of the features - figsize: size of the figure - show: Whether to show the plot - text_rotation: Amount of text rotation - min_perc: Define the minimum percentage of the total effect that a feature must contribute to be shown. - Defaults to 0.05. - - Returns: None - + interaction_dictionary: Dictionary of the interaction values """ - # Turn off interactive plot - if show is False: - plt.ioff() - - # Compute overall metrics - base_value = interaction_value.baseline_value - out_value = np.sum(interaction_value.values) # Sum of all values with the baseline value - # Format data - feature_to_names = {i: name for i, name in enumerate(feature_names)} - dict_values = interaction_value.dict_values pos_features = np.array( sorted( [ - [str(values), " x ".join([feature_to_names[f] for f in features])] - for features, values in dict_values.items() - if values >= 0 and len(features) > 0 + [str(value), format_labels(feature_to_names, coaltion)] + for coaltion, value in interaction_dictionary.items() + if value >= 0 and len(coaltion) > 0 ], key=lambda x: x[0], reverse=True, @@ -450,9 +428,9 @@ def _draw_force_plot( neg_features = np.array( sorted( [ - [str(values), " x ".join([feature_to_names[f] for f in features])] - for features, values in dict_values.items() - if values < 0 and len(features) > 0 + [str(value), " x ".join([feature_to_names[f] for f in coaltion])] + for coaltion, value in interaction_dictionary.items() + if value < 0 and len(coaltion) > 0 ], key=lambda x: x[0], reverse=True, @@ -487,18 +465,25 @@ def _draw_force_plot( else: total_pos = 0 - # Define plots - offset_text = (np.abs(total_neg) + np.abs(total_pos)) * 0.04 + return pos_features, neg_features, total_pos, total_neg - fig, ax = plt.subplots(figsize=figsize) - # Compute axis limit - update_axis_limits(ax, total_pos, pos_features, total_neg, neg_features, base_value, out_value) +def _add_bars( + ax: plt.Axes, out_value: float, pos_features: np.ndarray, neg_features: np.ndarray +) -> None: + """ + Add bars to the plot. + Args: + ax: Axes of the plot + out_value: grand total value + pos_features: positive features + neg_features: negative features - # Define width of bar + Returns: + + """ width_bar = 0.1 width_separators = (ax.get_xlim()[1] - ax.get_xlim()[0]) / 200 - # Create bar for negative shap values rectangle_list, separator_list = _create_bars( out_value, neg_features, "negative", width_separators, width_bar @@ -519,6 +504,53 @@ def _draw_force_plot( for i in separator_list: ax.add_patch(i) + +def _draw_force_plot( + interaction_value: InteractionValues, + feature_names: np.ndarray, + figsize: tuple[int, int], + show: bool = True, + text_rotation: float = 0, + min_perc: float = 0.05, +): + """ + Draw the force plot. + Args: + interaction_value: Interactiovalues ot be plotted + feature_names: names of the features + figsize: size of the figure + show: Whether to show the plot + text_rotation: Amount of text rotation + min_perc: Define the minimum percentage of the total effect that a feature must contribute to be shown. + Defaults to 0.05. + + Returns: None + + """ + # Turn off interactive plot + if show is False: + plt.ioff() + + # Compute overall metrics + base_value = interaction_value.baseline_value + out_value = np.sum(interaction_value.values) # Sum of all values with the baseline value + + # Split features into positive and negative values + pos_features, neg_features, total_pos, total_neg = _split_features( + interaction_value.dict_values, {i: name for i, name in enumerate(feature_names)}, out_value + ) + + # Define plots + offset_text = (np.abs(total_neg) + np.abs(total_pos)) * 0.04 + + fig, ax = plt.subplots(figsize=figsize) + + # Compute axis limit + update_axis_limits(ax, total_pos, pos_features, total_neg, neg_features, base_value, out_value) + + # Add the bars to the plot + _add_bars(ax, out_value, pos_features, neg_features) + # Add labels total_effect = np.abs(total_neg) + total_pos fig, ax = _add_labels( From ca42da9e12b639836d7fb70207d26bb1fbb6e180 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 12:45:30 +0100 Subject: [PATCH 08/21] Refactored abbreviation and label creation --- shapiq/plot/bar.py | 13 +++---------- shapiq/plot/utils.py | 9 +++++++++ shapiq/plot/watefall.py | 11 +---------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index 1eeb841d..f8c75798 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -7,7 +7,7 @@ from ..interaction_values import InteractionValues from ._config import BLUE, RED -from .utils import format_value +from .utils import abbreviate_feature_names, format_labels, format_value __all__ = ["bar_plot"] @@ -196,15 +196,6 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): return ax -def format_labels(feature_mapping, feature_tuple): - if len(feature_tuple) == 0: - return "Basevalue" - elif len(feature_tuple) == 1: - return str(feature_mapping[feature_tuple[0]]) - else: - return " x ".join([feature_mapping[f] for f in feature_tuple]) - - def bar_plot( list_of_interaction_values: list[InteractionValues], feature_names: Optional[np.ndarray] = None, @@ -229,6 +220,8 @@ def bar_plot( n_players = list_of_interaction_values[0].n_players if feature_names is not None: + if abbreviate: + feature_names = abbreviate_feature_names(feature_names) feature_mapping = {i: feature_names[i] for i in range(n_players)} else: feature_mapping = {i: str(i) for i in range(n_players)} diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index 629209eb..88c1567a 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -23,6 +23,15 @@ def format_value(s, format_str): return s +def format_labels(feature_mapping, feature_tuple): + if len(feature_tuple) == 0: + return "Baseval." + elif len(feature_tuple) == 1: + return str(feature_mapping[feature_tuple[0]]) + else: + return " x ".join([feature_mapping[f] for f in feature_tuple]) + + def get_interaction_values_and_feature_names( interaction_values: InteractionValues, feature_names: Optional[np.ndarray] = None, diff --git a/shapiq/plot/watefall.py b/shapiq/plot/watefall.py index 21b18960..fe286da7 100644 --- a/shapiq/plot/watefall.py +++ b/shapiq/plot/watefall.py @@ -8,7 +8,7 @@ from ..interaction_values import InteractionValues from ._config import BLUE, RED -from .utils import abbreviate_feature_names, format_value +from .utils import abbreviate_feature_names, format_labels, format_value __all__ = ["waterfall_plot"] @@ -322,15 +322,6 @@ def _draw_waterfall_plot( return plt.gca() -def format_labels(feature_mapping, feature_tuple): - if len(feature_tuple) == 0: - return "Basevalue" - elif len(feature_tuple) == 1: - return str(feature_mapping[feature_tuple[0]]) - else: - return " x ".join([feature_mapping[f] for f in feature_tuple]) - - def waterfall_plot( interaction_values: InteractionValues, feature_names: Optional[np.ndarray[str]] = None, From 7526c7c0e73e8eea138e4d56e7a583710f99cfbd Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 13:07:53 +0100 Subject: [PATCH 09/21] Refactor force. Add test of force with concrete example --- shapiq/plot/force.py | 33 ------------------------ tests/tests_plots/test_force.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index 33f8f6cc..28613286 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -318,39 +318,6 @@ def _add_base_value(base_value: float, ax: plt.Axes) -> None: text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) -def draw_higher_lower_element(out_value, offset_text): - plt.text( - out_value - offset_text, - 0.405, - "higher", - fontsize=13, - color="#FF0D57", - horizontalalignment="right", - ) - - plt.text( - out_value + offset_text, - 0.405, - "lower", - fontsize=13, - color="#1E88E5", - horizontalalignment="left", - ) - - plt.text( - out_value, 0.4, r"$\leftarrow$", fontsize=13, color="#1E88E5", horizontalalignment="center" - ) - - plt.text( - out_value, - 0.425, - r"$\rightarrow$", - fontsize=13, - color="#FF0D57", - horizontalalignment="center", - ) - - def update_axis_limits( ax: plt.Axes, total_pos: float, diff --git a/tests/tests_plots/test_force.py b/tests/tests_plots/test_force.py index b17290d6..75cc00ce 100644 --- a/tests/tests_plots/test_force.py +++ b/tests/tests_plots/test_force.py @@ -3,10 +3,55 @@ import matplotlib.pyplot as plt import numpy as np +import shapiq from shapiq.interaction_values import InteractionValues from shapiq.plot import force_plot +def test_force_concret(): + class CookingGame(shapiq.Game): + def __init__(self): + self.characteristic_function = { + (): 10, + (0,): 4, + (1,): 3, + (2,): 2, + (0, 1): 9, + (0, 2): 8, + (1, 2): 7, + (0, 1, 2): 15, + } + super().__init__( + n_players=3, + player_names=["Alice", "Bob", "Charlie"], # Optional list of names + normalization_value=self.characteristic_function[()], # 0 + normalize=False, + ) + + def value_function(self, coalitions: np.ndarray) -> np.ndarray: + """Defines the worth of a coalition as a lookup in the characteristic function.""" + output = [] + for coalition in coalitions: + output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) + return np.array(output) + + cooking_game = CookingGame() + + from shapiq import ExactComputer + + # create an ExactComputer object for the cooking game + exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game) + + # compute the Shapley Values for the game + sv_exact = exact_computer(index="k-SII", order=2) + print(sv_exact.dict_values) + + # visualize the Shapley Values + from shapiq.plot import force_plot + + force_plot(sv_exact, show=True) + + def test_force_plot(interaction_values_list: list[InteractionValues]): """Test the force plot function.""" iv = interaction_values_list[0] From f03416f1a4765db316d1028317b723ec5e1e67ad Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 13:14:42 +0100 Subject: [PATCH 10/21] Refactor bar. Add test of bar with concrete example --- shapiq/plot/bar.py | 11 ++------ tests/tests_plots/test_bar.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index f8c75798..bf9283d1 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -59,13 +59,6 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): """ # assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!" - - # ensure we at least have default feature names - if feature_names is None: - feature_names = np.array([f"Feature {i}" for i in range(len(values[0]))]) - if issubclass(type(feature_names), str): - feature_names = [i + " " + feature_names for i in range(len(values[0]))] - xlabel = "Shapley value" # determine how many top features we will plot @@ -201,7 +194,7 @@ def bar_plot( feature_names: Optional[np.ndarray] = None, show: bool = False, abbreviate: bool = True, - **kwargs, + max_display: Optional[int] = 10, ) -> Optional[plt.Axes]: """Draws interaction values on a bar plot. @@ -239,7 +232,7 @@ def bar_plot( ) ) - ax = _bar(values=values, feature_names=labels, show=False) + ax = _bar(values=values, feature_names=labels, show=False, max_display=max_display) ax.set_xlabel("Shapley value") if not show: return ax diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index 4e980045..a03d3945 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -3,10 +3,55 @@ import matplotlib.pyplot as plt import numpy as np +import shapiq from shapiq.interaction_values import InteractionValues from shapiq.plot import bar_plot +def test_bar_concret(): + class CookingGame(shapiq.Game): + def __init__(self): + self.characteristic_function = { + (): 10, + (0,): 4, + (1,): 3, + (2,): 2, + (0, 1): 9, + (0, 2): 8, + (1, 2): 7, + (0, 1, 2): 15, + } + super().__init__( + n_players=3, + player_names=["Alice", "Bob", "Charlie"], # Optional list of names + normalization_value=self.characteristic_function[()], # 0 + normalize=False, + ) + + def value_function(self, coalitions: np.ndarray) -> np.ndarray: + """Defines the worth of a coalition as a lookup in the characteristic function.""" + output = [] + for coalition in coalitions: + output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) + return np.array(output) + + cooking_game = CookingGame() + + from shapiq import ExactComputer + + # create an ExactComputer object for the cooking game + exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game) + + # compute the Shapley Values for the game + sv_exact = exact_computer(index="k-SII", order=2) + print(sv_exact.dict_values) + + # visualize the Shapley Values + from shapiq.plot import bar_plot + + bar_plot([sv_exact], show=True) + + def test_bar_plot(interaction_values_list: list[InteractionValues]): """Test the bar plot function.""" n_players = interaction_values_list[0].n_players @@ -27,3 +72,9 @@ def test_bar_plot(interaction_values_list: list[InteractionValues]): output = bar_plot(interaction_values_list, show=True) assert output is None plt.close("all") + + # test max_display=None + output = bar_plot(interaction_values_list, show=False, max_display=None) + assert output is not None + assert isinstance(output, plt.Axes) + plt.close("all") From 85ba7faf7a2e2d8ac134a2ab5db8b26c0c166a99 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 13:16:31 +0100 Subject: [PATCH 11/21] Add concrete test for waterfall --- tests/tests_plots/test_waterfall.py | 48 +++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/tests_plots/test_waterfall.py b/tests/tests_plots/test_waterfall.py index 55f2d961..9bd4c41d 100644 --- a/tests/tests_plots/test_waterfall.py +++ b/tests/tests_plots/test_waterfall.py @@ -7,6 +7,54 @@ from shapiq.plot import waterfall_plot +def test_waterfall_concrete(): + import numpy as np + + import shapiq + + class CookingGame(shapiq.Game): + def __init__(self): + self.characteristic_function = { + (): 10, + (0,): 4, + (1,): 3, + (2,): 2, + (0, 1): 9, + (0, 2): 8, + (1, 2): 7, + (0, 1, 2): 15, + } + super().__init__( + n_players=3, + player_names=["Alice", "Bob", "Charlie"], # Optional list of names + normalization_value=self.characteristic_function[()], # 0 + normalize=False, + ) + + def value_function(self, coalitions: np.ndarray) -> np.ndarray: + """Defines the worth of a coalition as a lookup in the characteristic function.""" + output = [] + for coalition in coalitions: + output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) + return np.array(output) + + cooking_game = CookingGame() + + from shapiq import ExactComputer + + # create an ExactComputer object for the cooking game + exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game) + + # compute the Shapley Values for the game + sv_exact = exact_computer(index="k-SII", order=2) + print(sv_exact.dict_values) + + # visualize the Shapley Values + from shapiq.plot import waterfall_plot + + waterfall_plot(sv_exact, show=True) + + def test_waterfall_plot(interaction_values_list: list[InteractionValues]): """Test the waterfall plot function.""" iv = interaction_values_list[0] From 7f01854da4cceb12add6bebe8157f01619e11682 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 13:17:41 +0100 Subject: [PATCH 12/21] Removed conversion function from shapiq to shap as not necessary anymore --- shapiq/plot/utils.py | 55 +------------------------------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index 88c1567a..3758bbfb 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -1,16 +1,9 @@ """This utility module contains helper functions for plotting.""" -import copy import re from collections.abc import Iterable -from typing import Optional -import numpy as np - -from ..interaction_values import InteractionValues -from ..utils import powerset - -__all__ = ["get_interaction_values_and_feature_names", "abbreviate_feature_names", "format_value"] +__all__ = ["abbreviate_feature_names", "format_value"] def format_value(s, format_str): @@ -32,52 +25,6 @@ def format_labels(feature_mapping, feature_tuple): return " x ".join([feature_mapping[f] for f in feature_tuple]) -def get_interaction_values_and_feature_names( - interaction_values: InteractionValues, - feature_names: Optional[np.ndarray] = None, - feature_values: Optional[np.ndarray] = None, - abbreviate: bool = True, -) -> tuple[np.ndarray, np.ndarray]: - """Converts higher-order interaction values to SHAP-like vectors with associated labels. - - Args: - interaction_values: The interaction values as an interaction object. - feature_names: The feature names used for plotting. If no feature names are provided, the - feature indices are used instead. Defaults to ``None``. - feature_values: The feature values used for plotting. Defaults to ``None``. - abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. - - Returns: - A tuple containing the SHAP values and the corresponding labels. - """ - feature_names = copy.deepcopy(feature_names) - if feature_names is not None and abbreviate: - feature_names = abbreviate_feature_names(feature_names) - _values_dict = {} - for i in range(1, interaction_values.max_order + 1): - _values_dict[i] = interaction_values.get_n_order_values(i) - _n_features = len(_values_dict[1]) - _shap_values = [] - _labels = [] - for interaction in powerset( - range(_n_features), min_size=1, max_size=interaction_values.max_order - ): - _order = len(interaction) - _values = _values_dict[_order] - _shap_values.append(_values[interaction]) - if feature_names is not None: - _name = " x ".join(str(feature_names[i]) for i in interaction) - else: - _name = " x ".join(f"{feature}" for feature in interaction) - if feature_values is not None: - _name += "\n" - _name += " x ".join(f"{feature_values[i]}".strip()[0:4] for i in interaction) - _labels.append(_name) - _shap_values = np.array(_shap_values) - _labels = np.array(_labels) - return _shap_values, _labels - - def abbreviate_feature_names(feature_names: Iterable[str]) -> list[str]: """A rudimentary function to abbreviate feature names for plotting. From de60c643a1ad4bd898fa198d4d6a9236e48bbc88 Mon Sep 17 00:00:00 2001 From: "Thies, Santo" Date: Sun, 5 Jan 2025 13:24:22 +0100 Subject: [PATCH 13/21] Removed imports of conversion function `get_interaction_values_and_feature_names`. --- shapiq/plot/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/shapiq/plot/__init__.py b/shapiq/plot/__init__.py index c5503158..3cb2a16c 100644 --- a/shapiq/plot/__init__.py +++ b/shapiq/plot/__init__.py @@ -7,7 +7,7 @@ from .si_graph import si_graph_plot from .stacked_bar import stacked_bar_plot from .upset import upset_plot -from .utils import abbreviate_feature_names, get_interaction_values_and_feature_names +from .utils import abbreviate_feature_names from .watefall import waterfall_plot __all__ = [ @@ -21,5 +21,4 @@ "upset_plot", # utils "abbreviate_feature_names", - "get_interaction_values_and_feature_names", ] From d1279fdbf4e9ec0c0f57d00e9b9b94ce5de6f811 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Wed, 8 Jan 2025 16:30:16 +0100 Subject: [PATCH 14/21] updated bar plot and added aggregation of InteractionValues object --- shapiq/interaction_values.py | 67 ++++++++++++++ shapiq/plot/bar.py | 126 +++++++++++++------------- shapiq/plot/utils.py | 2 +- tests/test_base_interaction_values.py | 55 ++++++++++- tests/tests_plots/test_bar.py | 2 +- 5 files changed, 188 insertions(+), 64 deletions(-) diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index 83b5d782..a83bcd34 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -769,3 +769,70 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]: from shapiq.plot.upset import upset_plot return upset_plot(self, show=show, **kwargs) + + +def aggregate_interaction_values( + interaction_values: list[InteractionValues], + aggregation: str = "mean", +) -> InteractionValues: + """Aggregates InteractionValues objects using a specific aggregation method. + + Args: + interaction_values: A list of InteractionValues objects to aggregate. + aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are + ``"median"``, ``"sum"``, ``"max"``, and ``"min"``. + + Returns: + The aggregated InteractionValues object. + + Note: + The index of the aggregated InteractionValues object is set to the index of the first + InteractionValues object in the list. + + Raises: + ValueError: If the aggregation method is not supported. + """ + + def _aggregate(vals: list[float], method: str) -> float: + """Does the actual aggregation of the values.""" + if method == "mean": + return np.mean(vals) + elif method == "median": + return np.median(vals) + elif method == "sum": + return np.sum(vals) + elif method == "max": + return np.max(vals) + elif method == "min": + return np.min(vals) + else: + raise ValueError(f"Aggregation method {method} is not supported.") + + # get all keys from all InteractionValues objects + all_keys = set() + for iv in interaction_values: + all_keys.update(iv.interaction_lookup.keys()) + + # aggregate the values + new_values = np.zeros(len(all_keys), dtype=float) + new_lookup = {} + for i, key in enumerate(all_keys): + new_lookup[key] = i + values = [iv[key] for iv in interaction_values] + new_values[i] = _aggregate(values, aggregation) + + max_order = max([iv.max_order for iv in interaction_values]) + min_order = min([iv.min_order for iv in interaction_values]) + n_players = max([iv.n_players for iv in interaction_values]) + + return InteractionValues( + values=new_values, + index=interaction_values[0].index, + max_order=max_order, + n_players=n_players, + min_order=min_order, + interaction_lookup=new_lookup, + estimated=True, + estimation_budget=None, + baseline_value=_aggregate([iv.baseline_value for iv in interaction_values], aggregation), + ) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index bf9283d1..e41c4a27 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -5,14 +5,14 @@ import matplotlib.pyplot as plt import numpy as np -from ..interaction_values import InteractionValues +from ..interaction_values import InteractionValues, aggregate_interaction_values from ._config import BLUE, RED from .utils import abbreviate_feature_names, format_labels, format_value __all__ = ["bar_plot"] -def _bar(values, feature_names, max_display=10, ax=None, show=True): +def _bar(values, feature_names, max_display=10, ax=None): """Create a bar plot of a set of SHAP values. Parameters @@ -29,24 +29,8 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): explanation objects. max_display : int How many top features to include in the bar plot (default is 10). - order : OpChain or numpy.ndarray - A function that returns a sort ordering given a matrix of SHAP values - and an axis, or a direct sample ordering given as a ``numpy.ndarray``. - - By default, take the absolute value. - clustering: np.ndarray or None - A partition tree, as returned by ``shap.utils.hclust`` - clustering_cutoff: float - Controls how much of the clustering structure is displayed. - show_data: bool or str - Controls if data values are shown as part of the y tick labels. If - "auto", we show the data only when there are no transforms. ax: matplotlib Axes Axes object to draw the plot onto, otherwise uses the current Axes. - show : bool - Whether ``matplotlib.pyplot.show()`` is called before returning. - Setting this to ``False`` allows the plot - to be customized further after it has been created. Returns ------- @@ -58,36 +42,45 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): See `bar plot examples `_. """ - # assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!" - xlabel = "Shapley value" - # determine how many top features we will plot + num_features = len(values[0]) if max_display is None: - max_display = len(feature_names) - num_features = min(max_display, len(values[0])) + max_display = num_features max_display = min(max_display, num_features) + num_cut = max(num_features - max_display, 0) # number of features that are not displayed - # Make it descending order + # get order of features in descending order feature_order = np.argsort(np.mean(values, axis=0))[::-1] - y_pos = np.arange(len(feature_order), 0, -1) - - # build our y-tick labels - yticklabels = [feature_names[i] for i in feature_order] - + # if there are more features than we are displaying then we aggregate the features not shown + if num_cut > 0: + cut_feature_values = values[:, feature_order[max_display:]] + sum_of_remaining = np.sum(cut_feature_values, axis=None) + index_of_last = feature_order[max_display] + values = np.insert(values, index_of_last, sum_of_remaining, axis=1) + max_display += 1 # include the sum of the remaining in the display + + # get the top features and their names + feature_inds = feature_order[:max_display] + y_pos = np.arange(len(feature_inds), 0, -1) + yticklabels = [feature_names[i] for i in feature_inds] + if num_cut > 0: + yticklabels[-1] = f"Sum of {int(num_cut)} other features" + + # create a figure if one was not provided if ax is None: ax = plt.gca() - # Only modify the figure size if ax was not passed in + # only modify the figure size if ax was not passed in # compute our figure size based on how many features we are showing fig = plt.gcf() row_height = 0.5 fig.set_size_inches( - 8 + 0.3 * max([len(f) for f in feature_names]), - num_features * row_height * np.sqrt(len(values)) + 1.5, + 8 + 0.3 * max([len(feature_name) for feature_name in feature_names]), + max_display * row_height * np.sqrt(len(values)) + 1.5, ) - # if negative values are present then we draw a vertical line to mark 0, otherwise the axis does this for us... - negative_values_present = np.sum(values[:, feature_order[:num_features]] < 0) > 0 + # if negative values are present, we draw a vertical line to mark 0 + negative_values_present = np.sum(values[:, feature_order[:max_display]] < 0) > 0 if negative_values_present: ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1) @@ -99,15 +92,15 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) ax.barh( y_pos + ypos_offset, - values[i, feature_order], + values[i, feature_inds], bar_width, align="center", color=[ - BLUE.hex if values[i, feature_order[j]] <= 0 else RED.hex for j in range(len(y_pos)) + BLUE.hex if values[i, feature_inds[j]] <= 0 else RED.hex for j in range(len(y_pos)) ], hatch=patterns[i], edgecolor=(1, 1, 1, 0.8), - label="Model " + str(i), + label="Group " + str(i + 1), ) # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks) @@ -118,19 +111,19 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ) xlen = ax.get_xlim()[1] - ax.get_xlim()[0] - # xticks = ax.get_xticks() bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted()) width = bbox.width bbox_to_xscale = xlen / width + # draw the bar labels as text next to the bars for i in range(len(values)): ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2) for j in range(len(y_pos)): - ind = feature_order[j] + ind = feature_inds[j] if values[i, ind] < 0: ax.text( values[i, ind] - (5 / 72) * bbox_to_xscale, - y_pos[j] + ypos_offset, + float(y_pos[j] + ypos_offset), format_value(values[i, ind], "%+0.02f"), horizontalalignment="right", verticalalignment="center", @@ -140,7 +133,7 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): else: ax.text( values[i, ind] + (5 / 72) * bbox_to_xscale, - y_pos[j] + ypos_offset, + float(y_pos[j] + ypos_offset), format_value(values[i, ind], "%+0.02f"), horizontalalignment="left", verticalalignment="center", @@ -149,9 +142,10 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ) # put horizontal lines for each feature row - for i in range(num_features): + for i in range(max_display): ax.axhline(i + 1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1) + # remove plot frame and y-axis ticks ax.xaxis.set_ticks_position("bottom") ax.yaxis.set_ticks_position("none") ax.spines["right"].set_visible(False) @@ -160,19 +154,15 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): ax.spines["left"].set_visible(False) ax.tick_params("x", labelsize=11) + # set the x-axis limits to cover the data xmin, xmax = ax.get_xlim() - ymin, ymax = ax.get_ylim() x_buffer = (xmax - xmin) * 0.05 - if negative_values_present: ax.set_xlim(xmin - x_buffer, xmax + x_buffer) else: ax.set_xlim(xmin, xmax + x_buffer) - # if features is None: - # pl.xlabel(labels["GLOBAL_VALUE"], fontsize=13) - # else: - ax.set_xlabel(xlabel, fontsize=13) + ax.set_xlabel("Attribution", fontsize=13) if len(values) > 1: ax.legend(fontsize=12, loc="lower right") @@ -180,13 +170,10 @@ def _bar(values, feature_names, max_display=10, ax=None, show=True): # color the y tick labels that have the feature values as gray # (these fall behind the black ones with just the feature name) tick_labels = ax.yaxis.get_majorticklabels() - for i in range(num_features): + for i in range(max_display): tick_labels[i].set_color("#999999") - if show: - plt.show() - else: - return ax + return ax def bar_plot( @@ -195,6 +182,7 @@ def bar_plot( show: bool = False, abbreviate: bool = True, max_display: Optional[int] = 10, + global_plot: bool = True, ) -> Optional[plt.Axes]: """Draws interaction values on a bar plot. @@ -207,9 +195,12 @@ def bar_plot( show: Whether ``matplotlib.pyplot.show()`` is called before returning. Default is ``True``. Setting this to ``False`` allows the plot to be customized further after it has been created. abbreviate: Whether to abbreviate the feature names. Defaults to ``True``. - **kwargs: Keyword arguments passed to ``shap.plots.beeswarm()``. + max_display: The maximum number of features to display. Defaults to ``10``. If set to + ``None``, all features are displayed. + global_plot: Weather to aggregate the values of the different InteractionValues objects + into a global explanation (``True``) or to plot them as separate bars (``False``). + Defaults to ``True``. """ - n_players = list_of_interaction_values[0].n_players if feature_names is not None: @@ -219,21 +210,34 @@ def bar_plot( else: feature_mapping = {i: str(i) for i in range(n_players)} - assert len(np.unique([iv.max_order for iv in list_of_interaction_values])) == 1 - - values = np.stack([iv.values for iv in list_of_interaction_values]) + if global_plot: + global_values = aggregate_interaction_values(list_of_interaction_values) + values = np.expand_dims(global_values.values, axis=0) + interaction_list = global_values.interaction_lookup.keys() + else: + all_interactions = set() + for iv in list_of_interaction_values: + all_interactions.update(iv.interaction_lookup.keys()) + all_interactions = sorted(all_interactions) + interaction_list = [] + values = np.zeros((len(list_of_interaction_values), len(all_interactions))) + for j, interaction in enumerate(all_interactions): + interaction_list.append(interaction) + for i, iv in enumerate(list_of_interaction_values): + values[i, j] = iv[interaction] + + # TODO: update this to be correct with the order of labels labels = np.array( list( map( lambda x: format_labels(feature_mapping, x), - list_of_interaction_values[0].dict_values.keys(), + interaction_list, ) ) ) - ax = _bar(values=values, feature_names=labels, show=False, max_display=max_display) - ax.set_xlabel("Shapley value") + ax = _bar(values=values, feature_names=labels, max_display=max_display) if not show: return ax plt.show() diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index 3758bbfb..2c45d65f 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -3,7 +3,7 @@ import re from collections.abc import Iterable -__all__ = ["abbreviate_feature_names", "format_value"] +__all__ = ["abbreviate_feature_names", "format_value", "format_labels"] def format_value(s, format_str): diff --git a/tests/test_base_interaction_values.py b/tests/test_base_interaction_values.py index 4e4e81bc..62e4bd9e 100644 --- a/tests/test_base_interaction_values.py +++ b/tests/test_base_interaction_values.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from shapiq.interaction_values import InteractionValues +from shapiq.interaction_values import InteractionValues, aggregate_interaction_values from shapiq.utils import powerset @@ -626,3 +626,56 @@ def test_subset(): assert subset_interaction_values.estimated == interaction_values.estimated assert subset_interaction_values.estimation_budget == interaction_values.estimation_budget assert subset_interaction_values.index == interaction_values.index + + +@pytest.mark.parametrize("aggregation", ["sum", "mean", "median", "max", "min"]) +def test_aggregation(aggregation): + + n_objects = 3 + + n, min_order, max_order = 5, 1, 3 + interaction_values_list = [] + for _ in range(n_objects): + values = np.random.rand(2**n - 1) + interaction_lookup = { + interaction: i for i, interaction in enumerate(powerset(range(n), min_order, max_order)) + } + interaction_values = InteractionValues( + values=values, + index="SII", + max_order=max_order, + n_players=n, + min_order=min_order, + interaction_lookup=interaction_lookup, + estimated=False, + estimation_budget=0, + baseline_value=0.0, + ) + interaction_values_list.append(interaction_values) + + aggregated_interaction_values = aggregate_interaction_values( + interaction_values_list, aggregation=aggregation + ) + + assert isinstance(aggregated_interaction_values, InteractionValues) + assert aggregated_interaction_values.index == "SII" + assert aggregated_interaction_values.n_players == n + assert aggregated_interaction_values.min_order == min_order + assert aggregated_interaction_values.max_order == max_order + + # check that all interactions are equal to the expected value + for interaction in powerset(range(n), 1, n): + aggregated_value = np.array( + [interaction_values[interaction] for interaction_values in interaction_values_list] + ) + if aggregation == "sum": + expected_value = np.sum(aggregated_value) + elif aggregation == "mean": + expected_value = np.mean(aggregated_value) + elif aggregation == "median": + expected_value = np.median(aggregated_value) + elif aggregation == "max": + expected_value = np.max(aggregated_value) + elif aggregation == "min": + expected_value = np.min(aggregated_value) + assert aggregated_interaction_values[interaction] == expected_value diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index a03d3945..befff43c 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -8,7 +8,7 @@ from shapiq.plot import bar_plot -def test_bar_concret(): +def test_bar_concrete(): class CookingGame(shapiq.Game): def __init__(self): self.characteristic_function = { From 56e3ef6c34df50f1e7065bbe5bc3fe55ccf4a6ba Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 16:17:48 +0100 Subject: [PATCH 15/21] updates aggregation method and finishes work on bar plot --- shapiq/interaction_values.py | 57 ++++++++++++++++++++++++++- shapiq/plot/bar.py | 55 +++++++------------------- shapiq/plot/utils.py | 49 ++++++++++++++++++++--- tests/test_base_interaction_values.py | 47 ++++++++++++++++++++++ tests/tests_plots/test_utils.py | 35 ++++++++++++++++ 5 files changed, 196 insertions(+), 47 deletions(-) create mode 100644 tests/tests_plots/test_utils.py diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index a83bcd34..cb1fb931 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -4,6 +4,7 @@ import copy import os import pickle +from collections.abc import Sequence from dataclasses import dataclass from typing import Optional, Union from warnings import warn @@ -630,6 +631,25 @@ def to_dict(self) -> dict: "baseline_value": self.baseline_value, } + def aggregate( + self, others: Sequence["InteractionValues"], aggregation: str = "mean" + ) -> "InteractionValues": + """Aggregates InteractionValues objects using a specific aggregation method. + + Args: + others: A list of InteractionValues objects to aggregate. + aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are + ``"median"``, ``"sum"``, ``"max"``, and ``"min"``. + + Returns: + The aggregated InteractionValues object. + + Note: + For documentation on the aggregation methods, see the ``aggregate_interaction_values()`` + function. + """ + return aggregate_interaction_values([self, *others], aggregation) + def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]: """Visualize InteractionValues on a graph. @@ -772,7 +792,7 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]: def aggregate_interaction_values( - interaction_values: list[InteractionValues], + interaction_values: Sequence[InteractionValues], aggregation: str = "mean", ) -> InteractionValues: """Aggregates InteractionValues objects using a specific aggregation method. @@ -785,6 +805,37 @@ def aggregate_interaction_values( Returns: The aggregated InteractionValues object. + Example: + >>> iv1 = InteractionValues( + ... values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]), + ... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5}, + ... index="SII", + ... max_order=2, + ... n_players=3, + ... min_order=1, + ... baseline_value=0.0, + ... ) + >>> iv2 = InteractionValues( + ... values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), # this iv is missing the (1, 2) value + ... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, # no (1, 2) + ... index="SII", + ... max_order=2, + ... n_players=3, + ... min_order=1, + ... baseline_value=1.0, + ... ) + >>> aggregate_interaction_values([iv1, iv2], "mean") + InteractionValues( + index=SII, max_order=2, min_order=1, estimated=True, estimation_budget=None, + n_players=3, baseline_value=0.5, + Top 10 interactions: + (1, 2): 0.60 + (0, 2): 0.35 + (0, 1): 0.25 + (0,): 0.15 + (1,): 0.25 + (2,): 0.35 + ) Note: The index of the aggregated InteractionValues object is set to the index of the first InteractionValues object in the list. @@ -812,6 +863,7 @@ def _aggregate(vals: list[float], method: str) -> float: all_keys = set() for iv in interaction_values: all_keys.update(iv.interaction_lookup.keys()) + all_keys = sorted(all_keys) # aggregate the values new_values = np.zeros(len(all_keys), dtype=float) @@ -824,6 +876,7 @@ def _aggregate(vals: list[float], method: str) -> float: max_order = max([iv.max_order for iv in interaction_values]) min_order = min([iv.min_order for iv in interaction_values]) n_players = max([iv.n_players for iv in interaction_values]) + baseline_value = _aggregate([iv.baseline_value for iv in interaction_values], aggregation) return InteractionValues( values=new_values, @@ -834,5 +887,5 @@ def _aggregate(vals: list[float], method: str) -> float: interaction_lookup=new_lookup, estimated=True, estimation_budget=None, - baseline_value=_aggregate([iv.baseline_value for iv in interaction_values], aggregation), + baseline_value=baseline_value, ) diff --git a/shapiq/plot/bar.py b/shapiq/plot/bar.py index e41c4a27..941590d2 100644 --- a/shapiq/plot/bar.py +++ b/shapiq/plot/bar.py @@ -12,35 +12,16 @@ __all__ = ["bar_plot"] -def _bar(values, feature_names, max_display=10, ax=None): +def _bar( + values: np.ndarray, + feature_names: np.ndarray, + max_display: Optional[int] = 10, + ax: Optional[plt.Axes] = None, +) -> plt.Axes: """Create a bar plot of a set of SHAP values. - Parameters - ---------- - shap_values : shap.Explanation or shap.Cohorts or dictionary of shap.Explanation objects - Passing a multi-row :class:`.Explanation` object creates a global - feature importance plot. - - Passing a single row of an explanation (i.e. ``shap_values[0]``) creates - a local feature importance plot. - - Passing a dictionary of Explanation objects will create a multiple-bar - plot with one bar type for each of the cohorts represented by the - explanation objects. - max_display : int - How many top features to include in the bar plot (default is 10). - ax: matplotlib Axes - Axes object to draw the plot onto, otherwise uses the current Axes. - - Returns - ------- - ax: matplotlib Axes - Returns the Axes object with the plot drawn onto it. Only returned if ``show=False``. - - Examples - -------- - See `bar plot examples `_. - + This is a modified version of the bar plot from the SHAP package. The original code can be found + at https://github.com/shap/shap. """ # determine how many top features we will plot num_features = len(values[0]) @@ -199,7 +180,8 @@ def bar_plot( ``None``, all features are displayed. global_plot: Weather to aggregate the values of the different InteractionValues objects into a global explanation (``True``) or to plot them as separate bars (``False``). - Defaults to ``True``. + Defaults to ``True``. If only one InteractionValues object is provided, this parameter + is ignored. """ n_players = list_of_interaction_values[0].n_players @@ -208,13 +190,14 @@ def bar_plot( feature_names = abbreviate_feature_names(feature_names) feature_mapping = {i: feature_names[i] for i in range(n_players)} else: - feature_mapping = {i: str(i) for i in range(n_players)} + feature_mapping = {i: "F" + str(i) for i in range(n_players)} + # aggregate the interaction values if global_plot is True if global_plot: global_values = aggregate_interaction_values(list_of_interaction_values) values = np.expand_dims(global_values.values, axis=0) interaction_list = global_values.interaction_lookup.keys() - else: + else: # plot the interaction values separately (also includes the case of a single object) all_interactions = set() for iv in list_of_interaction_values: all_interactions.update(iv.interaction_lookup.keys()) @@ -226,16 +209,8 @@ def bar_plot( for i, iv in enumerate(list_of_interaction_values): values[i, j] = iv[interaction] - # TODO: update this to be correct with the order of labels - - labels = np.array( - list( - map( - lambda x: format_labels(feature_mapping, x), - interaction_list, - ) - ) - ) + # format the labels + labels = [format_labels(feature_mapping, interaction) for interaction in interaction_list] ax = _bar(values=values, feature_names=labels, max_display=max_display) if not show: diff --git a/shapiq/plot/utils.py b/shapiq/plot/utils.py index 2c45d65f..664592f8 100644 --- a/shapiq/plot/utils.py +++ b/shapiq/plot/utils.py @@ -2,12 +2,30 @@ import re from collections.abc import Iterable +from typing import Union __all__ = ["abbreviate_feature_names", "format_value", "format_labels"] -def format_value(s, format_str): - """Strips trailing zeros and uses a unicode minus sign.""" +def format_value( + s: Union[float, str], + format_str: str = "%.2f", +) -> str: + """Strips trailing zeros and uses a unicode minus sign. + + Args: + s: The value to be formatted. + format_str: The format string to be used. Defaults to "%.2f". + + Returns: + str: The formatted value. + + Examples: + >>> format_value(1.0) + "1" + >>> format_value(1.234) + "1.23" + """ if not issubclass(type(s), str): s = format_str % s s = re.sub(r"\.?0+$", "", s) @@ -16,13 +34,34 @@ def format_value(s, format_str): return s -def format_labels(feature_mapping, feature_tuple): +def format_labels( + feature_mapping: dict[int, str], + feature_tuple: tuple[int, ...], +) -> str: + """Formats the feature labels for the plots. + + Args: + feature_mapping: A dictionary mapping feature indices to feature names. + feature_tuple: The feature tuple to be formatted. + + Returns: + str: The formatted feature tuple. + + Example: + >>> feature_mapping = {0: "A", 1: "B", 2: "C"} + >>> format_labels(feature_mapping, (0, 1)) + "A x B" + >>> format_labels(feature_mapping, (0,)) + "A" + >>> format_labels(feature_mapping, ()) + "Base Value" + """ if len(feature_tuple) == 0: - return "Baseval." + return "Base Value" elif len(feature_tuple) == 1: return str(feature_mapping[feature_tuple[0]]) else: - return " x ".join([feature_mapping[f] for f in feature_tuple]) + return " x ".join([str(feature_mapping[f]) for f in feature_tuple]) def abbreviate_feature_names(feature_names: Iterable[str]) -> list[str]: diff --git a/tests/test_base_interaction_values.py b/tests/test_base_interaction_values.py index 62e4bd9e..a93f066f 100644 --- a/tests/test_base_interaction_values.py +++ b/tests/test_base_interaction_values.py @@ -679,3 +679,50 @@ def test_aggregation(aggregation): elif aggregation == "min": expected_value = np.min(aggregated_value) assert aggregated_interaction_values[interaction] == expected_value + + # test aggregate from InteractionValues object + aggregated_from_object = interaction_values_list[0].aggregate( + aggregation=aggregation, others=interaction_values_list[1:] + ) + assert isinstance(aggregated_from_object, InteractionValues) + assert aggregated_from_object == aggregated_interaction_values # same values + assert aggregated_from_object is not aggregated_interaction_values # but different objects + + +def test_docs_aggregation_function(): + """Tests the aggregation function in the InteractionValues dataclass like in the docs.""" + + iv1 = InteractionValues( + values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]), + index="SII", + n_players=3, + min_order=1, + max_order=2, + interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5}, + baseline_value=0.0, + ) + + # this does not contain the (1, 2) interaction (i.e. is 0) + iv2 = InteractionValues( + values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), + index="SII", + n_players=3, + min_order=1, + max_order=2, + interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, + baseline_value=1.0, + ) + + # test sum + aggregated_interaction_values = aggregate_interaction_values([iv1, iv2], aggregation="sum") + assert pytest.approx(aggregated_interaction_values[(0,)]) == 0.3 + assert pytest.approx(aggregated_interaction_values[(1,)]) == 0.5 + assert pytest.approx(aggregated_interaction_values[(1, 2)]) == 0.6 + assert pytest.approx(aggregated_interaction_values.baseline_value) == 1.0 + + # test mean + aggregated_interaction_values = aggregate_interaction_values([iv1, iv2], aggregation="mean") + assert pytest.approx(aggregated_interaction_values[(0,)]) == 0.15 + assert pytest.approx(aggregated_interaction_values[(1,)]) == 0.25 + assert pytest.approx(aggregated_interaction_values[(1, 2)]) == 0.3 + assert pytest.approx(aggregated_interaction_values.baseline_value) == 0.5 diff --git a/tests/tests_plots/test_utils.py b/tests/tests_plots/test_utils.py new file mode 100644 index 00000000..243963d8 --- /dev/null +++ b/tests/tests_plots/test_utils.py @@ -0,0 +1,35 @@ +"""This test module tests all plotting utilities.""" + +from shapiq.plot.utils import abbreviate_feature_names, format_labels, format_value + + +def test_format_value(): + """Test the format_value function.""" + assert format_value(1.0) == "1" + assert format_value(1.234) == "1.23" + assert format_value(-1.234) == "\u22121.23" + assert format_value("1.234") == "1.234" + + +def test_format_labels(): + """Test the format_labels function.""" + feature_mapping = {0: "A", 1: "B", 2: "C"} + assert format_labels(feature_mapping, (0, 1)) == "A x B" + assert format_labels(feature_mapping, (0,)) == "A" + assert format_labels(feature_mapping, ()) == "Base Value" + assert format_labels(feature_mapping, (0, 1, 2)) == "A x B x C" + + +def test_abbreviate_feature_names(): + """Tests the abbreviate_feature_names function.""" + # check for splitting characters + feature_names = ["feature-0", "feature_1", "feature 2", "feature.3"] + assert abbreviate_feature_names(feature_names) == ["F0", "F1", "F2", "F3"] + + # check for long names + feature_names = ["longfeaturenamethatisnotshort", "stilllong"] + assert abbreviate_feature_names(feature_names) == ["lon.", "sti."] + + # check for abbreviation with capital letters + feature_names = ["LongFeatureName", "Short"] + assert abbreviate_feature_names(feature_names) == ["LFN", "Sho."] From 28d2e83d104b776ef1f787c783b96a79b2637131 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 17:44:53 +0100 Subject: [PATCH 16/21] updated force_plot --- shapiq/plot/force.py | 183 +++++++++++++++++----------- tests/conftest.py | 34 ++++++ tests/tests_plots/test_bar.py | 41 +------ tests/tests_plots/test_force.py | 58 +++------ tests/tests_plots/test_waterfall.py | 60 ++------- 5 files changed, 179 insertions(+), 197 deletions(-) diff --git a/shapiq/plot/force.py b/shapiq/plot/force.py index 28613286..debb9a49 100644 --- a/shapiq/plot/force.py +++ b/shapiq/plot/force.py @@ -313,7 +313,7 @@ def _add_base_value(base_value: float, ax: plt.Axes) -> None: ax.add_line(line) text_out_val = ax.text( - base_value, 0.33, "base value", fontsize=12, alpha=1, horizontalalignment="center" + base_value, 0.25, "base value", fontsize=12, alpha=1, horizontalalignment="center" ) text_out_val.set_bbox(dict(facecolor="white", edgecolor="white")) @@ -375,37 +375,34 @@ def _split_features( feature_to_names: dict[int, str], out_value: float, ) -> tuple[np.ndarray, np.ndarray, float, float]: - """ - Split the features into positive and negative values. + """Splits the features into positive and negative values. + Args: - interaction_dictionary: Dictionary of the interaction values - """ - pos_features = np.array( - sorted( - [ - [str(value), format_labels(feature_to_names, coaltion)] - for coaltion, value in interaction_dictionary.items() - if value >= 0 and len(coaltion) > 0 - ], - key=lambda x: x[0], - reverse=True, - ), - dtype=object, - ) - neg_features = np.array( - sorted( - [ - [str(value), " x ".join([feature_to_names[f] for f in coaltion])] - for coaltion, value in interaction_dictionary.items() - if value < 0 and len(coaltion) > 0 - ], - key=lambda x: x[0], - reverse=True, - ), - dtype=object, - ) + interaction_dictionary: Dictionary containing the interaction values mapping from + feature indices to their values. + feature_to_names: Dictionary mapping feature indices to feature names. + out_value: The output value. - # Convert negative feature values to plot values + Returns: + tuple: A tuple containing the positive features, negative features, total positive value, + and total negative value. + """ + # split features into positive and negative values + pos_features, neg_features = [], [] + for coaltion, value in interaction_dictionary.items(): + if len(coaltion) == 0: + continue + label = format_labels(feature_to_names, coaltion) + if value >= 0: + pos_features.append([str(value), label]) + elif value < 0: + neg_features.append([str(value), label]) + pos_features = sorted(pos_features, key=lambda x: x[0], reverse=True) + neg_features = sorted(neg_features, key=lambda x: x[0], reverse=True) + pos_features = np.array(pos_features, dtype=object) + neg_features = np.array(neg_features, dtype=object) + + # convert negative feature values to plot values neg_val = out_value for i in neg_features: val = float(i[0]) @@ -418,7 +415,7 @@ def _split_features( else: total_neg = 0 - # Convert positive feature values to plot values + # convert positive feature values to plot values pos_val = out_value for i in pos_features: val = float(i[0]) @@ -472,53 +469,80 @@ def _add_bars( ax.add_patch(i) +def draw_higher_lower_element(out_value, offset_text): + plt.text( + out_value - offset_text, + 0.35, + "higher", + fontsize=13, + color="#FF0D57", + horizontalalignment="right", + ) + plt.text( + out_value + offset_text, + 0.35, + "lower", + fontsize=13, + color="#1E88E5", + horizontalalignment="left", + ) + plt.text( + out_value, 0.34, r"$\leftarrow$", fontsize=13, color="#1E88E5", horizontalalignment="center" + ) + plt.text( + out_value, + 0.36, + r"$\rightarrow$", + fontsize=13, + color="#FF0D57", + horizontalalignment="center", + ) + + def _draw_force_plot( interaction_value: InteractionValues, feature_names: np.ndarray, figsize: tuple[int, int], - show: bool = True, - text_rotation: float = 0, min_perc: float = 0.05, -): + draw_higher_lower: bool = True, +) -> plt.Figure: """ Draw the force plot. Args: - interaction_value: Interactiovalues ot be plotted - feature_names: names of the features - figsize: size of the figure - show: Whether to show the plot - text_rotation: Amount of text rotation - min_perc: Define the minimum percentage of the total effect that a feature must contribute to be shown. - Defaults to 0.05. + interaction_value: The ``InteractionValues`` to be plotted. + feature_names: Names of the features to be plotted provided as an array. + figsize: The size of the figure. + min_perc: Define the minimum percentage of the total effect that a feature must contribute + to be shown in the plot. Defaults to 0.05. Returns: None """ - # Turn off interactive plot - if show is False: - plt.ioff() + # turn off interactive plot + plt.ioff() - # Compute overall metrics + # compute overall metrics base_value = interaction_value.baseline_value out_value = np.sum(interaction_value.values) # Sum of all values with the baseline value - # Split features into positive and negative values + # split features into positive and negative values + features_to_names = {i: str(name) for i, name in enumerate(feature_names)} pos_features, neg_features, total_pos, total_neg = _split_features( - interaction_value.dict_values, {i: name for i, name in enumerate(feature_names)}, out_value + interaction_value.dict_values, features_to_names, out_value ) - # Define plots + # define plots offset_text = (np.abs(total_neg) + np.abs(total_pos)) * 0.04 fig, ax = plt.subplots(figsize=figsize) - # Compute axis limit + # compute axis limit update_axis_limits(ax, total_pos, pos_features, total_neg, neg_features, base_value, out_value) - # Add the bars to the plot + # add the bars to the plot _add_bars(ax, out_value, pos_features, neg_features) - # Add labels + # add labels total_effect = np.abs(total_neg) + total_pos fig, ax = _add_labels( fig, @@ -529,7 +553,7 @@ def _draw_force_plot( offset_text, total_effect, min_perc=min_perc, - text_rotation=text_rotation, + text_rotation=0, ) fig, ax = _add_labels( @@ -541,42 +565,63 @@ def _draw_force_plot( offset_text, total_effect, min_perc=min_perc, - text_rotation=text_rotation, + text_rotation=0, ) - # Add label for base value + # add higher and lower element + if draw_higher_lower: + draw_higher_lower_element(out_value, offset_text) + + # add label for base value _add_base_value(base_value, ax) - # Add output label + # add output label out_names = "" _add_output_element(out_names, out_value, ax) - if show: - plt.show() - else: - return plt.gcf() + # fix the whitespace around the plot + plt.tight_layout() + + return plt.gcf() def force_plot( interaction_values: InteractionValues, feature_names: Optional[np.ndarray] = None, - show: bool = False, abbreviate: bool = True, + show: bool = False, + figsize: tuple[int, int] = (15, 4), + draw_higher_lower: bool = True, + min_percentage: float = 0.05, ) -> Optional[plt.Figure]: - """ - Draw a force plot. + """Draws a force plot for the given interaction values. + Args: - interaction_values: - feature_names: - show: - abbreviate: + interaction_values: The ``InteractionValues`` to be plotted. + feature_names: The names of the features. If ``None``, the features are named by their index. + show: Whether to show or return the plot. Defaults to ``False`` and returns the plot. + abbreviate: Whether to abbreviate the feature names. Defaults to ``True.`` + figsize: The size of the figure. Defaults to ``(15, 4)``. + draw_higher_lower: Whether to draw the higher and lower indicator. Defaults to ``True``. + min_percentage: Define the minimum percentage of the total effect that a feature must contribute + to be shown in the plot. Defaults to 0.05. Returns: + plt.Figure: The figure of the plot """ if feature_names is None: - feature_names = np.array([str(i) for i in range(interaction_values.n_players)]) + feature_names = [str(i) for i in range(interaction_values.n_players)] if abbreviate: feature_names = abbreviate_feature_names(feature_names) - - return _draw_force_plot(interaction_values, feature_names, figsize=(20, 3), show=show) + feature_names = np.array(feature_names) + plot = _draw_force_plot( + interaction_values, + feature_names, + figsize=figsize, + draw_higher_lower=draw_higher_lower, + min_perc=min_percentage, + ) + if not show: + return plot + plt.show() diff --git a/tests/conftest.py b/tests/conftest.py index 7da6bfca..5614b438 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,40 @@ NR_FEATURES = 7 +@pytest.fixture +def cooking_game(): + """Return a simple game object.""" + import shapiq + + class CookingGame(shapiq.Game): + def __init__(self): + self.characteristic_function = { + (): 10, + (0,): 4, + (1,): 3, + (2,): 2, + (0, 1): 9, + (0, 2): 8, + (1, 2): 7, + (0, 1, 2): 15, + } + super().__init__( + n_players=3, + player_names=["Alice", "Bob", "Charlie"], # Optional list of names + normalization_value=self.characteristic_function[()], # 0 + normalize=False, + ) + + def value_function(self, coalitions: np.ndarray) -> np.ndarray: + """Defines the worth of a coalition as a lookup in the characteristic function.""" + output = [] + for coalition in coalitions: + output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) + return np.array(output) + + return CookingGame() + + @pytest.fixture def dt_reg_model() -> DecisionTreeRegressor: """Return a simple decision tree model.""" diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index befff43c..71eac042 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -3,52 +3,19 @@ import matplotlib.pyplot as plt import numpy as np -import shapiq -from shapiq.interaction_values import InteractionValues -from shapiq.plot import bar_plot +from shapiq import ExactComputer, InteractionValues, bar_plot -def test_bar_concrete(): - class CookingGame(shapiq.Game): - def __init__(self): - self.characteristic_function = { - (): 10, - (0,): 4, - (1,): 3, - (2,): 2, - (0, 1): 9, - (0, 2): 8, - (1, 2): 7, - (0, 1, 2): 15, - } - super().__init__( - n_players=3, - player_names=["Alice", "Bob", "Charlie"], # Optional list of names - normalization_value=self.characteristic_function[()], # 0 - normalize=False, - ) - - def value_function(self, coalitions: np.ndarray) -> np.ndarray: - """Defines the worth of a coalition as a lookup in the characteristic function.""" - output = [] - for coalition in coalitions: - output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) - return np.array(output) - - cooking_game = CookingGame() - - from shapiq import ExactComputer +def test_bar_cooking_game(cooking_game): + """Test the bar plot function with concrete values from the cooking game.""" # create an ExactComputer object for the cooking game - exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game) + exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) # compute the Shapley Values for the game sv_exact = exact_computer(index="k-SII", order=2) print(sv_exact.dict_values) - # visualize the Shapley Values - from shapiq.plot import bar_plot - bar_plot([sv_exact], show=True) diff --git a/tests/tests_plots/test_force.py b/tests/tests_plots/test_force.py index 75cc00ce..2837066c 100644 --- a/tests/tests_plots/test_force.py +++ b/tests/tests_plots/test_force.py @@ -3,53 +3,23 @@ import matplotlib.pyplot as plt import numpy as np -import shapiq -from shapiq.interaction_values import InteractionValues -from shapiq.plot import force_plot +from shapiq import ExactComputer, InteractionValues, force_plot -def test_force_concret(): - class CookingGame(shapiq.Game): - def __init__(self): - self.characteristic_function = { - (): 10, - (0,): 4, - (1,): 3, - (2,): 2, - (0, 1): 9, - (0, 2): 8, - (1, 2): 7, - (0, 1, 2): 15, - } - super().__init__( - n_players=3, - player_names=["Alice", "Bob", "Charlie"], # Optional list of names - normalization_value=self.characteristic_function[()], # 0 - normalize=False, - ) - - def value_function(self, coalitions: np.ndarray) -> np.ndarray: - """Defines the worth of a coalition as a lookup in the characteristic function.""" - output = [] - for coalition in coalitions: - output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) - return np.array(output) - - cooking_game = CookingGame() - - from shapiq import ExactComputer - - # create an ExactComputer object for the cooking game - exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game) - - # compute the Shapley Values for the game - sv_exact = exact_computer(index="k-SII", order=2) - print(sv_exact.dict_values) - - # visualize the Shapley Values - from shapiq.plot import force_plot +def test_force_cooking_game(cooking_game): + """Test the force plot function with concrete values from the cooking game.""" + exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) + interaction_values = exact_computer(index="k-SII", order=2) + print(interaction_values.dict_values) + force_plot(interaction_values, show=True, min_percentage=0.2) + plt.close() - force_plot(sv_exact, show=True) + # visual inspection: + # - E[f(X)] = 10 + # - f(x) = 15 + # - 0, 1, and 2 should individually have negative contributions (go left) + # - all interactions should have a positive +7 contribution (go right) + # - feature 0 is too small to be displayed because of min_percentage=0.2 def test_force_plot(interaction_values_list: list[InteractionValues]): diff --git a/tests/tests_plots/test_waterfall.py b/tests/tests_plots/test_waterfall.py index 9bd4c41d..cd06d7f9 100644 --- a/tests/tests_plots/test_waterfall.py +++ b/tests/tests_plots/test_waterfall.py @@ -3,56 +3,22 @@ import matplotlib.pyplot as plt import numpy as np -from shapiq.interaction_values import InteractionValues -from shapiq.plot import waterfall_plot +from shapiq import ExactComputer, InteractionValues, waterfall_plot -def test_waterfall_concrete(): - import numpy as np - - import shapiq - - class CookingGame(shapiq.Game): - def __init__(self): - self.characteristic_function = { - (): 10, - (0,): 4, - (1,): 3, - (2,): 2, - (0, 1): 9, - (0, 2): 8, - (1, 2): 7, - (0, 1, 2): 15, - } - super().__init__( - n_players=3, - player_names=["Alice", "Bob", "Charlie"], # Optional list of names - normalization_value=self.characteristic_function[()], # 0 - normalize=False, - ) - - def value_function(self, coalitions: np.ndarray) -> np.ndarray: - """Defines the worth of a coalition as a lookup in the characteristic function.""" - output = [] - for coalition in coalitions: - output.append(self.characteristic_function[tuple(np.where(coalition)[0])]) - return np.array(output) - - cooking_game = CookingGame() - - from shapiq import ExactComputer - - # create an ExactComputer object for the cooking game - exact_computer = ExactComputer(n_players=cooking_game.n_players, game_fun=cooking_game) - - # compute the Shapley Values for the game - sv_exact = exact_computer(index="k-SII", order=2) - print(sv_exact.dict_values) - - # visualize the Shapley Values - from shapiq.plot import waterfall_plot +def test_waterfall_cooking_game(cooking_game): + """Test the waterfall plot function with concrete values from the cooking game.""" + exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) + interaction_values = exact_computer(index="k-SII", order=2) + print(interaction_values.dict_values) + waterfall_plot(interaction_values, show=True) + plt.close() - waterfall_plot(sv_exact, show=True) + # visual inspection: + # - E[f(X)] = 10 + # - f(x) = 15 + # - 0, 1, and 2 should individually have negative contributions (go left) + # - all interactions should have a positive +7 contribution (go right) def test_waterfall_plot(interaction_values_list: list[InteractionValues]): From 4113d174597f845f89eba7a8ffc5f8e7587da718 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 17:47:27 +0100 Subject: [PATCH 17/21] updated test for the bar plot --- tests/tests_plots/test_bar.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index 71eac042..57ada54a 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -8,16 +8,14 @@ def test_bar_cooking_game(cooking_game): """Test the bar plot function with concrete values from the cooking game.""" - - # create an ExactComputer object for the cooking game exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) - - # compute the Shapley Values for the game sv_exact = exact_computer(index="k-SII", order=2) print(sv_exact.dict_values) - bar_plot([sv_exact], show=True) + # visual inspection: + # - Order from top to bottom: Base Value, the interactions (all equal), F0, F1, F2 + def test_bar_plot(interaction_values_list: list[InteractionValues]): """Test the bar plot function.""" From 0fb57dfe05c4176f7bcf3e344a5670a360ec6bd8 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 17:51:56 +0100 Subject: [PATCH 18/21] updated force plot test --- tests/tests_plots/test_force.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_plots/test_force.py b/tests/tests_plots/test_force.py index 2837066c..8cfd128a 100644 --- a/tests/tests_plots/test_force.py +++ b/tests/tests_plots/test_force.py @@ -11,7 +11,8 @@ def test_force_cooking_game(cooking_game): exact_computer = ExactComputer(n_players=cooking_game.n_players, game=cooking_game) interaction_values = exact_computer(index="k-SII", order=2) print(interaction_values.dict_values) - force_plot(interaction_values, show=True, min_percentage=0.2) + feature_names = list(cooking_game.player_name_lookup.keys()) + force_plot(interaction_values, show=True, min_percentage=0.2, feature_names=feature_names) plt.close() # visual inspection: From 31283a15802265b9eca623fb6545cfdf4ac43224 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 18:08:14 +0100 Subject: [PATCH 19/21] updated tests for plots --- shapiq/plot/watefall.py | 17 +++++++---------- tests/conftest.py | 4 +++- tests/tests_plots/test_bar.py | 6 ++++++ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/shapiq/plot/watefall.py b/shapiq/plot/watefall.py index fe286da7..4b8a19db 100644 --- a/shapiq/plot/watefall.py +++ b/shapiq/plot/watefall.py @@ -220,7 +220,7 @@ def _draw_waterfall_plot( if text_bbox.width > arrow_bbox.width: txt_obj.remove() - txt_obj = plt.text( + plt.text( neg_lefts[i] - (5 / 72) * bbox_to_xscale + dist, neg_inds[i], format_value(neg_widths[i], "%+0.02f"), @@ -346,15 +346,12 @@ def waterfall_plot( feature_names = abbreviate_feature_names(feature_names) feature_mapping = {i: feature_names[i] for i in range(interaction_values.n_players)} - data = np.array( - [ - (format_labels(feature_mapping, feature_tuple), str(value)) - for feature_tuple, value in interaction_values.dict_values.items() - if len(feature_tuple) > 0 - ], - dtype=object, - ) - + # create the data for the waterfall plot in the correct format + data = [] + for feature_tuple, value in interaction_values.dict_values.items(): + if len(feature_tuple) > 0: + data.append((format_labels(feature_mapping, feature_tuple), str(value))) + data = np.array(data, dtype=object) values = data[:, 1].astype(float) feature_names = data[:, 0] diff --git a/tests/conftest.py b/tests/conftest.py index 5614b438..b8c5a40e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -360,6 +360,8 @@ def mae_loss(): @pytest.fixture def interaction_values_list(): """Returns a list of three InteractionValues objects.""" + rng = np.random.RandomState(42) + from shapiq.interaction_values import InteractionValues from shapiq.utils import powerset @@ -375,7 +377,7 @@ def interaction_values_list(): powerset(range(n_players), min_size=min_order, max_size=max_order) ): interaction_lookup[interaction] = i - values.append(np.random.rand()) + values.append(rng.uniform(0, 1)) values = np.array(values) iv = InteractionValues( n_players=n_players, diff --git a/tests/tests_plots/test_bar.py b/tests/tests_plots/test_bar.py index 57ada54a..16109109 100644 --- a/tests/tests_plots/test_bar.py +++ b/tests/tests_plots/test_bar.py @@ -43,3 +43,9 @@ def test_bar_plot(interaction_values_list: list[InteractionValues]): assert output is not None assert isinstance(output, plt.Axes) plt.close("all") + + # test global = false + output = bar_plot(interaction_values_list, show=False, global_plot=False) + assert output is not None + assert isinstance(output, plt.Axes) + plt.close("all") From ddcf5c7b26150e2f8a30169bccb2a1b722203744 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 19:02:27 +0100 Subject: [PATCH 20/21] removes shap from requirements --- requirements.txt | 1 - tests/requirements/requirements.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index b9bf4c2a..3b7c68ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,6 @@ ruff==0.8.4 scikit-image==0.25.0 scikit-learn==1.6.0 scipy==1.14.1 -shap==0.46.0 tqdm==4.67.1 torch==2.5.1 torchvision==0.20.1 diff --git a/tests/requirements/requirements.txt b/tests/requirements/requirements.txt index 303ab789..7a750393 100644 --- a/tests/requirements/requirements.txt +++ b/tests/requirements/requirements.txt @@ -11,7 +11,6 @@ ruff==0.6.2 scikit-image==0.24.0 scikit-learn==1.5.1 scipy==1.13.0 -shap==0.46.0 tqdm==4.66.5 torch==2.4.0 torchvision==0.19.0 From 7fcafa2fa77cf079284f5a5ffca53506a719f69c Mon Sep 17 00:00:00 2001 From: Maximilian Date: Fri, 10 Jan 2025 19:12:52 +0100 Subject: [PATCH 21/21] removed call to shap in test --- .../tests_explainer/test_explainer_tabular.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/tests_explainer/test_explainer_tabular.py b/tests/tests_explainer/test_explainer_tabular.py index 829eb586..0d5f33bf 100644 --- a/tests/tests_explainer/test_explainer_tabular.py +++ b/tests/tests_explainer/test_explainer_tabular.py @@ -176,25 +176,34 @@ def test_explain(dt_model, data, index, budget, max_order, imputer): def test_against_shap_linear(): """Tests weather TabularExplainer yields similar results as SHAP with a basic linear model.""" - import shap n_samples = 3 dim = 5 + rng = np.random.default_rng(42) def make_linear_model(): - w = np.random.default_rng().normal(size=dim) + w = rng.normal(size=dim) def model(X: np.ndarray): return np.dot(X, w) return model - X = np.random.default_rng().normal(size=(n_samples, dim)) + X = rng.normal(size=(n_samples, dim)) model = make_linear_model() + # import shap # compute with shap - explainer_shap = shap.explainers.Exact(model, X) - shap_values = explainer_shap(X).values + # explainer_shap = shap.explainers.Exact(model, X) + # shap_values = explainer_shap(X).values + # print(shap_values) + shap_values = np.array( + [ + [-0.29565839, -0.36698085, -0.55970434, 0.22567077, 0.05852208], + [1.08513574, 0.06365536, 0.46312977, -0.61532757, 0.00370387], + [-0.78947735, 0.30332549, 0.09657457, 0.38965679, -0.06222595], + ] + ) # compute with shapiq explainer_shapiq = TabularExplainer(