Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix data being displayed in waterfall plots #35

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions examples/plot_basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import matplotlib.pyplot as plt
from sklearn.utils import check_random_state
from sharp import ShaRP
from sharp.utils import scores_to_ordering

# Set up some envrionment variables
RNG_SEED = 42
Expand All @@ -31,6 +32,7 @@ def score_function(X):
[rng.normal(size=(N_SAMPLES, 1)), rng.binomial(1, 0.5, size=(N_SAMPLES, 1))], axis=1
)
y = score_function(X)
rank = scores_to_ordering(y)


######################################################################################
Expand Down Expand Up @@ -65,17 +67,29 @@ def score_function(X):
######################################################################################
# We can also turn these into visualizations:

plt.style.use('seaborn-v0_8-whitegrid')

# Visualization of feature contributions
print("Sample 2 feature values:", X[2])
print("Sample 3 feature values:", X[3])
fig, axes = plt.subplots(1, 2)
fig, axes = plt.subplots(1, 2, figsize=(13.5, 4.5), layout="constrained")

# Bar plot comparing two points
xai.plot.bar(pair_scores, ax=axes[0])
axes[0].set_title("Pairwise comparison (Sample 2 vs 3)")
xai.plot.bar(pair_scores, ax=axes[0], color="#ff0051")
axes[0].set_title(
f"Pairwise comparison - Sample 2 (rank {rank[2]}) vs 3 (rank {rank[3]})",
fontsize=12,
y=-0.2
)
axes[0].set_xlabel("")
axes[0].set_ylabel("Contribution to rank", fontsize=12)
axes[0].tick_params(axis='both', which='major', labelsize=12)

# Waterfall explaining rank for sample 2
axes[1] = xai.plot.waterfall(individual_scores)
axes[1].suptitle("Rank explanation for Sample 9")
axes[1] = xai.plot.waterfall(
individual_scores, feature_values=X[9], mean_target_value=rank.mean()
)
ax = axes[1].gca()
ax.set_title("Rank explanation for Sample 9", fontsize=12, y=-0.2)

plt.show()
2 changes: 1 addition & 1 deletion sharp/utils/_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def check_feature_names(X):
feature_names = _get_feature_names(X)

if feature_names is None:
feature_names = np.indices([X.shape[1]]).squeeze()
feature_names = np.array([f"Feature {i}" for i in range(X.shape[1])])

return feature_names

Expand Down
8 changes: 4 additions & 4 deletions sharp/visualization/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def bar(self, scores, ax=None, **kwargs):

return ax

def waterfall(self, scores, mean_shapley_value=0):
def waterfall(self, contributions, feature_values=None, mean_target_value=0):
"""
TODO: refactor waterfall plot code.
"""
Expand All @@ -37,10 +37,10 @@ def waterfall(self, scores, mean_shapley_value=0):
rank_dict = {
"upper_bounds": None,
"lower_bounds": None,
"features": None, # pd.Series(feature_names),
"features": feature_values, # pd.Series(feature_names),
"data": None, # pd.Series(ind_values, index=feature_names),
"base_values": mean_shapley_value,
"base_values": mean_target_value,
"feature_names": feature_names,
"values": pd.Series(scores, index=feature_names),
"values": pd.Series(contributions, index=feature_names),
}
return _waterfall(rank_dict, max_display=10)
25 changes: 16 additions & 9 deletions sharp/visualization/_waterfall.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import pandas as pd
import numpy as np
from sklearn.utils.validation import check_array
from sharp.utils._utils import _optional_import


Expand Down Expand Up @@ -30,15 +31,19 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
plt.ioff()

base_values = float(shap_values["base_values"])
features = shap_values["values"]
features = (
np.array(shap_values["features"])
if shap_values["features"] is not None
else np.array(shap_values["values"])
)
feature_names = shap_values["feature_names"]
# lower_bounds = shap_values["lower_bounds"]
# upper_bounds = shap_values["upper_bounds"]
values = shap_values["values"]

# init variables we use for tracking the plot locations
num_features = min(max_display, len(values))
row_height = 0.5
# row_height = 0.5
rng = range(num_features - 1, -1, -1)
order = np.argsort(-np.abs(values))
pos_lefts = []
Expand All @@ -55,7 +60,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
yticklabels = ["" for _ in range(num_features + 1)]

# size the plot based on how many features we are plotting
plt.gcf().set_size_inches(8, num_features * row_height + 1.5)
# plt.gcf().set_size_inches(8, num_features * row_height + 1.5)

# see how many individual (vs. grouped at the end) features we are plotting
if num_features == len(values):
Expand All @@ -66,7 +71,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
# compute the locations of the individual features and plot the dashed connecting
# lines
for i in range(num_individual):
sval = values[order[i]]
sval = values.iloc[order.iloc[i]]
loc -= sval
if sval >= 0:
pos_inds.append(rng[i])
Expand All @@ -92,17 +97,19 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
zorder=-1,
)
if features is None:
yticklabels[rng[i]] = feature_names[order[i]]
yticklabels[rng[i]] = feature_names[order.iloc[i]]
else:
if np.issubdtype(type(features[order[i]]), np.number):
if np.issubdtype(type(features[order.iloc[i]]), np.number):
yticklabels[rng[i]] = (
format_value(float(features[order[i]]), "%0.03f")
format_value(float(features[order.iloc[i]]), "%0.03f")
+ " = "
+ feature_names[order[i]]
+ feature_names[order.iloc[i]]
)
else:
yticklabels[rng[i]] = (
str(features[order[i]]) + " = " + str(feature_names[order[i]])
str(features[order.iloc[i]])
+ " = "
+ str(feature_names[order.iloc[i]])
)

# add a last grouped feature to represent the impact of all the features we didn't
Expand Down
Loading