Skip to content

Commit

Permalink
add typing to arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang committed Dec 7, 2023
1 parent 85f152a commit 7614265
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions dianna/methods/lime_tabular.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""LIME tabular explainer."""
from typing import Iterable
from typing import Union
import numpy as np
from lime.lime_tabular import LimeTabularExplainer
from dianna import utils
Expand All @@ -9,18 +11,18 @@ class LIMETabular:

def __init__(
self,
training_data,
mode='classification',
feature_names=None,
categorical_features=None,
kernel_width=25,
kernel=None,
verbose=False,
class_names=None,
feature_selection='auto',
random_state=None,
training_data: np.array,
mode: str = "classification",
feature_names: list[int] = None,
categorical_features: list[int] = None,
kernel_width: int = 25,
kernel: callable = None,
verbose: bool = False,
class_names: list[str] = None,
feature_selection: str = "auto",
random_state: int = None,
**kwargs,
):
) -> None:
"""Initializes Lime explainer.
For numerical features, perturb them by sampling from a Normal(0,1) and
Expand All @@ -37,9 +39,9 @@ def __init__(
Args:
training_data (np.array): numpy 2d array
mode (str, optional): "classification" or "regression"
feature_names (strings, optional): list of names corresponding to the columns
feature_names (list(str), optional): list of names corresponding to the columns
in the training data.
categorical_features (ints, optional): list of indices corresponding to the
categorical_features (list(int), optional): list of indices corresponding to the
categorical columns. Values in these
columns MUST be integers.
kernel_width (int, optional): kernel width
Expand All @@ -49,15 +51,14 @@ def __init__(
the classifier is using. If not present, class names
will be '0', '1', ...
feature_selection (str, optional): feature selection
discretize_continuous (bool, optional): if True, all non-categorical features
will be discretized into quartiles.
random_state (int or np.RandomState, optional): seed or random state
kwargs: These parameters are passed on
"""
self.mode = mode
init_instance_kwargs = utils.get_kwargs_applicable_to_function(
LimeTabularExplainer, kwargs)
LimeTabularExplainer, kwargs
)

# temporary solution for setting num_features and top_labels
self.num_features = len(feature_names)
Expand All @@ -79,12 +80,12 @@ def __init__(

def explain(
self,
model_or_function,
input_tabular,
labels=(1, ),
num_samples=5000,
model_or_function: Union[str, callable],
input_tabular: np.array,
labels: Iterable[int] = (1,),
num_samples: int = 5000,
**kwargs,
):
) -> np.array:
"""Run the LIME explainer.
Args:
Expand All @@ -103,7 +104,8 @@ def explain(
"""
# run the explanation.
explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
self.explainer.explain_instance, kwargs)
self.explainer.explain_instance, kwargs
)
runner = utils.get_function(model_or_function)

explanation = self.explainer.explain_instance(
Expand All @@ -116,11 +118,11 @@ def explain(
**explain_instance_kwargs,
)

if self.mode == 'regression':
if self.mode == "regression":
local_exp = sorted(explanation.local_exp[1])
saliency = [i[1] for i in local_exp]

elif self.mode == 'classification':
elif self.mode == "classification":
# extract scores from lime explainer
saliency = []
for i in range(self.top_labels):
Expand Down

0 comments on commit 7614265

Please sign in to comment.