From 4492710e4f5bed335f376815bbb5edde67baa077 Mon Sep 17 00:00:00 2001 From: Christiaan Meijer Date: Wed, 4 Dec 2024 18:59:26 +0100 Subject: [PATCH 1/2] Update kernelshap_tabular_land_atmosphere.ipynb --- .../KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb b/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb index 15ab6fcb..925edefa 100644 --- a/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb +++ b/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb @@ -1182,6 +1182,9 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" + }, + "execution": { + "timeout": 1800 } }, "nbformat": 4, From d92c03b2fc8ae2bd69ec674343c01467f1ab52e8 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Thu, 19 Dec 2024 14:20:53 +0100 Subject: [PATCH 2/2] introducing key locally_run to disable time consuming cell for github action --- .../kernelshap_tabular_land_atmosphere.ipynb | 162 ++++++++++-------- 1 file changed, 93 insertions(+), 69 deletions(-) diff --git a/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb b/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb index 925edefa..2570f304 100644 --- a/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb +++ b/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb @@ -720,7 +720,7 @@ }, { "cell_type": "markdown", - "id": "80411a9d-881e-4196-8559-17aaadd15841", + "id": "ddb1e4f0-2674-4242-bcd8-abf66f97c611", "metadata": {}, "source": [ "#### 5 - Run the explainer at one location, several data instances (here as an example one month time series)\n", @@ -805,6 +805,24 @@ "background_data = x_train.drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()" ] }, + { + "cell_type": "markdown", + "id": "8b612e55-e1ec-40dc-b189-65d90ffb2b1c", 
+ "metadata": {}, + "source": [ + "This step takes a few minutes, so it is not suitable for GitHub Actions. If you want to run this step locally, set `locally_run = True`. With the default `locally_run = False`, the explainer, plotting, and metric cells below are skipped." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "59a54eaa-f6f2-42b5-8849-aceb37b06156", + "metadata": {}, + "outputs": [], + "source": [ + "locally_run = False" + ] + }, { "cell_type": "code", "execution_count": 14, @@ -821,11 +839,12 @@ ], "source": [ "# run explainer over time series, this might take a few minutes\n", - "explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n", - " mode ='regression', training_data=background_data, training_data_kmeans=5,\n", - " feature_names=features.columns, silent=True)\n", - "\n", - "print(\"Dianna is done!\") " + "if locally_run:\n", + " explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n", + " mode ='regression', training_data=background_data, training_data_kmeans=5,\n", + " feature_names=features.columns, silent=True)\n", + " \n", + " print(\"Dianna is done!\") " ] }, { @@ -846,30 +865,31 @@ } ], "source": [ - "# create shap_values object\n", - "shap_values = Explanation(explanations[key])\n", - "shap_values.feature_names = features.columns\n", - "\n", - "# create comparison plot: predictions vs test data \n", - "y_predict_time = runner(features.to_numpy())\n", - "y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n", - "comparison_plot(y_test_time, y_predict_time, show=False) \n", - "comparison_img = plt.gcf()\n", - "plt.close()\n", - "\n", - "# create summary plot\n", - "shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n", - "summary_img = plt.gcf()\n", - "plt.close()\n", - "\n", - "# create heatmap plot\n", - "shap.plots.heatmap(shap_values, cmap=\"bwr\", 
show=False, max_display=15)\n", - "heatmap_img = plt.gcf()\n", - "plt.close()\n", - "\n", - "# plot all three figures in one cell\n", - "figures = [comparison_img, heatmap_img, summary_img]\n", - "display_figures(figures, captions, 1, 3)" + "if locally_run:\n", + " # create shap_values object\n", + " shap_values = Explanation(explanations[key])\n", + " shap_values.feature_names = features.columns\n", + " \n", + " # create comparison plot: predictions vs test data \n", + " y_predict_time = runner(features.to_numpy())\n", + " y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n", + " comparison_plot(y_test_time, y_predict_time, show=False) \n", + " comparison_img = plt.gcf()\n", + " plt.close()\n", + " \n", + " # create summary plot\n", + " shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n", + " summary_img = plt.gcf()\n", + " plt.close()\n", + " \n", + " # create heatmap plot\n", + " shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n", + " heatmap_img = plt.gcf()\n", + " plt.close()\n", + " \n", + " # plot all three figures in one cell\n", + " figures = [comparison_img, heatmap_img, summary_img]\n", + " display_figures(figures, captions, 1, 3)" ] }, { @@ -887,9 +907,10 @@ } ], "source": [ - "relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n", - "cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n", - "print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")" + "if locally_run:\n", + " relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n", + " cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n", + " print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")" ] }, { @@ -947,12 +968,13 @@ } ], "source": [ - "# run explainer over time series, this might take a few 
minutes\n", - "explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n", - " mode ='regression', training_data=background_data, training_data_kmeans=5,\n", - " feature_names=features.columns, silent=True)\n", - "\n", - "print(\"Dianna is done!\") " + "if locally_run:\n", + " # run explainer over time series, this might take a few minutes\n", + " explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n", + " mode ='regression', training_data=background_data, training_data_kmeans=5,\n", + " feature_names=features.columns, silent=True)\n", + " \n", + " print(\"Dianna is done!\") " ] }, { @@ -973,30 +995,31 @@ } ], "source": [ - "# create shap_values object\n", - "shap_values = Explanation(explanations[key])\n", - "shap_values.feature_names = features.columns\n", - "\n", - "# create comparison plot: predictions vs test data \n", - "y_predict_time = runner(features.to_numpy())\n", - "y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n", - "comparison_plot(y_test_time, y_predict_time, show=False) \n", - "comparison_img = plt.gcf()\n", - "plt.close()\n", - "\n", - "# create summary plot\n", - "shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n", - "summary_img = plt.gcf()\n", - "plt.close()\n", - "\n", - "# create heatmap plot\n", - "shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n", - "heatmap_img = plt.gcf()\n", - "plt.close()\n", - "\n", - "# plot all three figures in one cell\n", - "figures = [comparison_img, heatmap_img, summary_img]\n", - "display_figures(figures, captions, 1, 3)" + "if locally_run:\n", + " # create shap_values object\n", + " shap_values = Explanation(explanations[key])\n", + " shap_values.feature_names = features.columns\n", + " \n", + " # create 
comparison plot: predictions vs test data \n", + " y_predict_time = runner(features.to_numpy())\n", + " y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n", + " comparison_plot(y_test_time, y_predict_time, show=False) \n", + " comparison_img = plt.gcf()\n", + " plt.close()\n", + " \n", + " # create summary plot\n", + " shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n", + " summary_img = plt.gcf()\n", + " plt.close()\n", + " \n", + " # create heatmap plot\n", + " shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n", + " heatmap_img = plt.gcf()\n", + " plt.close()\n", + " \n", + " # plot all three figures in one cell\n", + " figures = [comparison_img, heatmap_img, summary_img]\n", + " display_figures(figures, captions, 1, 3)" ] }, { @@ -1014,9 +1037,10 @@ } ], "source": [ - "relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n", - "cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n", - "print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")" + "if locally_run:\n", + " relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n", + " cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n", + " print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")" ] }, { @@ -1166,6 +1190,9 @@ } ], "metadata": { + "execution": { + "timeout": 1800 + }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", @@ -1182,9 +1209,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" - }, - "execution": { - "timeout": 1800 } }, "nbformat": 4,