diff --git a/dianna/methods/lime_tabular.py b/dianna/methods/lime_tabular.py index b95e8d92..839e0f70 100644 --- a/dianna/methods/lime_tabular.py +++ b/dianna/methods/lime_tabular.py @@ -1,4 +1,6 @@ """LIME tabular explainer.""" +from typing import Iterable +from typing import Union import numpy as np from lime.lime_tabular import LimeTabularExplainer from dianna import utils @@ -9,18 +11,18 @@ class LIMETabular: def __init__( self, - training_data, - mode='classification', - feature_names=None, - categorical_features=None, - kernel_width=25, - kernel=None, - verbose=False, - class_names=None, - feature_selection='auto', - random_state=None, + training_data: np.ndarray, + mode: str = "classification", + feature_names: list[str] = None, + categorical_features: list[int] = None, + kernel_width: int = 25, + kernel: callable = None, + verbose: bool = False, + class_names: list[str] = None, + feature_selection: str = "auto", + random_state: int = None, **kwargs, - ): + ) -> None: """Initializes Lime explainer. For numerical features, perturb them by sampling from a Normal(0,1) and @@ -37,9 +39,9 @@ def __init__( Args: training_data (np.array): numpy 2d array mode (str, optional): "classification" or "regression" - feature_names (strings, optional): list of names corresponding to the columns + feature_names (list(str), optional): list of names corresponding to the columns in the training data. - categorical_features (ints, optional): list of indices corresponding to the + categorical_features (list(int), optional): list of indices corresponding to the categorical columns. Values in these columns MUST be integers. kernel_width (int, optional): kernel width @@ -49,15 +51,14 @@ def __init__( the classifier is using. If not present, class names will be '0', '1', ... feature_selection (str, optional): feature selection - discretize_continuous (bool, optional): if True, all non-categorical features - will be discretized into quartiles. 
random_state (int or np.RandomState, optional): seed or random state kwargs: These parameters are passed on """ self.mode = mode init_instance_kwargs = utils.get_kwargs_applicable_to_function( - LimeTabularExplainer, kwargs) + LimeTabularExplainer, kwargs + ) # temporary solution for setting num_features and top_labels self.num_features = len(feature_names) @@ -79,12 +80,12 @@ def __init__( def explain( self, - model_or_function, - input_tabular, - labels=(1, ), - num_samples=5000, + model_or_function: Union[str, callable], + input_tabular: np.ndarray, + labels: Iterable[int] = (1,), + num_samples: int = 5000, **kwargs, - ): + ) -> np.ndarray: """Run the LIME explainer. Args: @@ -103,7 +104,8 @@ def explain( """ # run the explanation. explain_instance_kwargs = utils.get_kwargs_applicable_to_function( - self.explainer.explain_instance, kwargs) + self.explainer.explain_instance, kwargs + ) runner = utils.get_function(model_or_function) explanation = self.explainer.explain_instance( @@ -116,11 +118,11 @@ def explain( **explain_instance_kwargs, ) - if self.mode == 'regression': + if self.mode == "regression": local_exp = sorted(explanation.local_exp[1]) saliency = [i[1] for i in local_exp] - elif self.mode == 'classification': + elif self.mode == "classification": # extract scores from lime explainer saliency = [] for i in range(self.top_labels):