Skip to content

Commit

Permalink
add tutorial notebook for regression
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang committed Dec 12, 2023
1 parent 2d83bdd commit 765dd67
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 9 deletions.
7 changes: 6 additions & 1 deletion dianna/methods/kernelshap_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class KERNELSHAPTabular:
def __init__(
self,
training_data: np.array,
mode: str = "classification",
feature_names: List[int] = None,
training_data_kmeans: Optional[int] = None,
) -> None:
Expand All @@ -26,6 +27,7 @@ def __init__(
Arguments:
training_data (np.array): training data, which should be numpy 2d array
mode (str, optional): "classification" or "regression"
feature_names (list(str), optional): list of names corresponding to the columns
in the training data.
training_data_kmeans(int, optional): summarize the whole training set with
Expand All @@ -36,7 +38,7 @@ def __init__(
else:
self.training_data = training_data
self.feature_names = feature_names

self.mode = mode
self.explainer: KernelExplainer

def explain(
Expand Down Expand Up @@ -76,4 +78,7 @@ def explain(

saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)

if self.mode == 'regression':
return saliency[0]

return saliency
4 changes: 2 additions & 2 deletions dianna/methods/lime_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def explain(
**explain_instance_kwargs,
)

if self.mode == "regression":
if self.mode == 'regression':
local_exp = sorted(explanation.local_exp[1])
saliency = [i[1] for i in local_exp]

elif self.mode == "classification":
elif self.mode == 'classification':
# extract scores from lime explainer
saliency = []
for i in range(self.top_labels):
Expand Down
10 changes: 4 additions & 6 deletions tutorials/kernelshap_tabular_penguin.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
"<img width=\"150\" alt=\"Logo_ER10\" src=\"https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png\">\n",
"\n",
"### Model Interpretation using KernelSHAP for penguin dataset classifier\n",
"This notebook demonstrates the use of DIANNA with the SHAP Kernel explainer tabular method on the penguins dataset.\n",
"\n",
"https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/model_agnostic/Census%20income%20classification%20with%20scikit-learn.html"
"This notebook demonstrates the use of DIANNA with the SHAP Kernel explainer tabular method on the penguins dataset."
]
},
{
Expand Down Expand Up @@ -348,7 +346,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. Applying LIME with DIANNA\n",
"#### 3. Applying KernelSHAP with DIANNA\n",
"The simplest way to run DIANNA on image data is with `dianna.explain_tabular`.\n",
"\n",
"DIANNA requires input in numpy format, so the input data is converted into a numpy array.\n",
Expand All @@ -373,8 +371,8 @@
],
"source": [
"explanation = dianna.explain_tabular(run_model, input_tabular=data_instance, method='kernelshap',\n",
" training_data = X_train, training_data_kmeans = 5,\n",
" feature_names=input_features.columns)"
" mode ='classification', training_data = X_train,\n",
" training_data_kmeans = 5, feature_names=input_features.columns)"
]
},
{
Expand Down
302 changes: 302 additions & 0 deletions tutorials/kernelshap_tabular_weather.ipynb

Large diffs are not rendered by default.

0 comments on commit 765dd67

Please sign in to comment.