Skip to content

Commit

Permalink
Merge branch 'main' into 546-masking-time-step-segmentation
Browse files Browse the repository at this point in the history
# Conflicts:
#	tests/methods/test_rise_timeseries.py
  • Loading branch information
cwmeijer committed Jan 25, 2024
2 parents 728139a + dbf1e93 commit 463b91f
Show file tree
Hide file tree
Showing 39 changed files with 2,599 additions and 424 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.2.0
current_version = 1.3.0

[comment]
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved
Expand Down
4 changes: 2 additions & 2 deletions .github/actions/install-python-and-package/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ description: "Installs Python, updates pip and installs DIANNA together with its
inputs:
python-version:
required: false
description: "The Python version to use. Specify major and minor version, e.g. '3.9'."
default: "3.9"
description: "The Python version to use. Specify major and minor version, e.g. '3.10'."
default: "3.10"
extras-require:
required: false
description: "The extras dependencies packages to be installed, for instance 'docs' or 'publishing,notebooks'."
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
fail-fast: false
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ['3.8', '3.11']
python-version: ['3.9', '3.11']
exclude:
# already tested in build_single job
- python-version: 3.11
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ on:
jobs:

notebooks:
name: Run notebooks on (3.9, ${{ matrix.os }})
name: Run notebooks on (3.10, ${{ matrix.os }})
if: github.event.pull_request.draft == false
runs-on: ${{ matrix.os }}
strategy:
Expand Down
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ authors:
name-particle: "van der"

doi: 10.5281/zenodo.5801485
version: "1.2.0"
version: "1.3.0"
repository-code: "https://github.com/dianna-ai/dianna"
keywords:
- XAI
Expand Down
60 changes: 47 additions & 13 deletions README.md

Large diffs are not rendered by default.

84 changes: 67 additions & 17 deletions dianna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,23 @@
"""
import importlib
import logging
import warnings
from . import utils

logging.getLogger(__name__).addHandler(logging.NullHandler())

__author__ = 'DIANNA Team'
__email__ = 'dianna-ai@esciencecenter.nl'
__version__ = '1.2.0'
__version__ = '1.3.0'


def explain_timeseries(model_or_function, timeseries_data, method, labels,
**kwargs):
def explain_timeseries(model_or_function, input_timeseries, method, labels, **kwargs):
"""Explain timeseries data given a model and a chosen method.
Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to a ONNX model on disk.
timeseries_data (np.ndarray): Timeseries data to be explained
input_timeseries (np.ndarray): Timeseries data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int)): Labels to be explained
kwargs: key word arguments
Expand All @@ -49,18 +49,24 @@ def explain_timeseries(model_or_function, timeseries_data, method, labels,
"""
explainer = _get_explainer(method, kwargs, modality='Timeseries')
explain_timeseries_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs)
return explainer.explain(model_or_function, timeseries_data, labels,
**explain_timeseries_kwargs)
explainer.explain, kwargs
)
for key in explain_timeseries_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
return explainer.explain(
model_or_function, input_timeseries, labels, **explain_timeseries_kwargs
)


def explain_image(model_or_function, input_data, method, labels, **kwargs):
def explain_image(model_or_function, input_image, method, labels, **kwargs):
"""Explain an image (input_data) given a model and a chosen method.
Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to a ONNX model on disk.
input_data (np.ndarray): Image data to be explained
input_image (np.ndarray): Image data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int)): Labels to be explained
kwargs: These keyword parameters are passed on
Expand All @@ -74,13 +80,18 @@ def explain_image(model_or_function, input_data, method, labels, **kwargs):
from onnx_tf.backend import prepare # noqa: F401
explainer = _get_explainer(method, kwargs, modality='Image')
explain_image_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs)
return explainer.explain(model_or_function, input_data, labels,
**explain_image_kwargs)
explainer.explain, kwargs
)
for key in explain_image_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
return explainer.explain(
model_or_function, input_image, labels, **explain_image_kwargs
)


def explain_text(model_or_function, input_text, tokenizer, method, labels,
**kwargs):
def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwargs):
"""Explain text (input_text) given a model and a chosen method.
Args:
Expand All @@ -98,7 +109,12 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels,
"""
explainer = _get_explainer(method, kwargs, modality='Text')
explain_text_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs)
explainer.explain, kwargs
)
for key in explain_text_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
return explainer.explain(
model_or_function=model_or_function,
input_text=input_text,
Expand All @@ -108,10 +124,40 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels,
)


def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kwargs):
    """Explain tabular data (input_tabular) given a model and a chosen method.

    Args:
        model_or_function (callable or str): The function that runs the model to be explained _or_
                                             the path to a ONNX model on disk.
        input_tabular (np.ndarray): Tabular data to be explained
        method (string): One of the supported methods: RISE, LIME or KernelSHAP
        labels (Iterable(int), optional): Labels to be explained
        kwargs: These keyword parameters are passed on. Each one is routed to the
                explainer's constructor or to its ``explain`` method, depending on
                which of the two accepts it; a warning is issued for any that
                neither accepts.

    Returns:
        One heatmap (2D array) per class.

    Raises:
        ValueError: Propagated from the explainer lookup when no implementation
                    exists for the requested method and the tabular modality.
    """
    # The lookup consumes (pops) the kwargs that the explainer's constructor takes.
    explainer = _get_explainer(method, kwargs, modality='Tabular')
    explain_tabular_kwargs = utils.get_kwargs_applicable_to_function(
        explainer.explain, kwargs
    )
    # Drop the kwargs that explain() will receive; whatever remains is unused.
    for key in explain_tabular_kwargs:
        kwargs.pop(key)
    if kwargs:
        # stacklevel=2 attributes the warning to the caller's line instead of
        # this module, which is where the offending keyword was written.
        warnings.warn(
            f'Please note the following kwargs are not being used: {kwargs}',
            stacklevel=2,
        )
    return explainer.explain(
        model_or_function=model_or_function,
        input_tabular=input_tabular,
        labels=labels,
        **explain_tabular_kwargs,
    )

def _get_explainer(method, kwargs, modality):
try:
method_submodule = importlib.import_module(
f'dianna.methods.{method.lower()}_{modality.lower()}')
f'dianna.methods.{method.lower()}_{modality.lower()}'
)
except ImportError as err:
raise ValueError(
f'Method {method.lower()}_{modality.lower()} does not exist'
Expand All @@ -123,5 +169,9 @@ def _get_explainer(method, kwargs, modality):
f'Data modality {modality} is not available for method {method.upper()}'
) from err
method_kwargs = utils.get_kwargs_applicable_to_function(
method_class.__init__, kwargs)
method_class.__init__, kwargs
)
# Remove used kwargs from list of kwargs passed to the function.
for key in method_kwargs.keys():
kwargs.pop(key)
return method_class(**method_kwargs)
2 changes: 1 addition & 1 deletion dianna/dashboard/_models_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def run_model(ts_data):

explanation = dianna.explain_timeseries(
run_model,
timeseries_data=ts_data[0],
input_timeseries=ts_data[0],
method='RISE',
**kwargs,
)
Expand Down
84 changes: 84 additions & 0 deletions dianna/methods/kernelshap_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import shap
from shap import KernelExplainer
from dianna import utils


class KERNELSHAPTabular:
    """Wrapper around the SHAP Kernel explainer for tabular data."""

    def __init__(
        self,
        training_data: np.ndarray,
        mode: str = "classification",
        feature_names: Optional[List[str]] = None,
        training_data_kmeans: Optional[int] = None,
    ) -> None:
        """Initializer of KERNELSHAPTabular.

        Training data must be provided for the explainer to estimate the expected
        values.

        More information can be found in the API guide:
        https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py

        Arguments:
            training_data (np.array): training data, which should be numpy 2d array
            mode (str, optional): "classification" or "regression"
            feature_names (list(str), optional): list of names corresponding to the columns
                                                 in the training data.
            training_data_kmeans(int, optional): summarize the whole training set with
                                                 weighted kmeans
        """
        # A falsy value (None or 0) means: keep the training data as provided.
        if training_data_kmeans:
            self.training_data = shap.kmeans(training_data, training_data_kmeans)
        else:
            self.training_data = training_data
        self.feature_names = feature_names
        self.mode = mode
        # Created by explain(); initialised to None so that reading the attribute
        # before the first explain() call does not raise AttributeError.
        self.explainer: Optional[KernelExplainer] = None

    def explain(
        self,
        model_or_function: Union[str, callable],
        input_tabular: np.ndarray,
        link: str = "identity",
        **kwargs,
    ) -> np.ndarray:
        """Run the KernelSHAP explainer.

        Args:
            model_or_function (callable or str): The function that runs the model to be explained
                                                 or the path to a ONNX model on disk.
            input_tabular (np.ndarray): Data to be explained.
            link (str): A generalized linear model link to connect the feature importance values
                        to the model. Must be either "identity" or "logit".
            kwargs: These parameters are passed on. Those accepted by ``KernelExplainer``
                    go to its constructor, those accepted by ``shap_values`` go to that
                    call; any others are silently ignored here.

        Other keyword arguments: see the documentation for KernelExplainer:
        https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py

        Returns:
            The result of ``KernelExplainer.shap_values`` for ``input_tabular``. In
            "regression" mode only its first element is returned.
        """
        # NOTE(review): model_or_function is forwarded to KernelExplainer unchanged;
        # confirm that a str path to an ONNX model is turned into a callable
        # somewhere before it reaches this point.
        init_instance_kwargs = utils.get_kwargs_applicable_to_function(
            KernelExplainer, kwargs
        )
        self.explainer = KernelExplainer(
            model_or_function, self.training_data, link, **init_instance_kwargs
        )

        explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
            self.explainer.shap_values, kwargs
        )

        saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)

        # Regression callers get the first element only, not the full result.
        if self.mode == 'regression':
            return saliency[0]

        return saliency
Loading

0 comments on commit 463b91f

Please sign in to comment.