Commit

lime: refacto for pylint
Antonin POCHE committed Nov 7, 2023
1 parent 6855221 commit 1e3c993
Showing 1 changed file with 46 additions and 112 deletions.
158 changes: 46 additions & 112 deletions xplique/attributions/lime.py
@@ -144,15 +144,13 @@ def __init__(
self.ref_value = ref_value
self.nb_samples = nb_samples

@sanitize_input_output
def explain(self,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> tf.Tensor:
self.batch_size = self.batch_size or self.nb_samples

def _set_shape_dependant_parameters(self,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray]):
"""
This method attributes the output of the model with given targets
to the inputs of the model using the approach described above,
training an interpretable model and returning a representation of the
interpretable model.
Set default values for the parameters that depend on the data type, and thus on the
input shape: `ref_value` and `map_to_interpret_space`.
Parameters
----------
@@ -161,85 +159,38 @@ def explain(self,
If Dataset, targets should not be provided (included in Dataset).
Expected shape among (N, W), (N, T, W), (N, H, W, C).
More information in the documentation.
targets
Tensor or Array. One-hot encoding of the model's output from which an explanation
is desired. One encoding per input and only one output at a time. Therefore,
the expected shape is (N, output_size).
More information in the documentation.
Returns
-------
explanations
Interpretable coefficients, same shape as the inputs, except for the channels.
Coefficients of the interpretable model: features that were grouped together in the
interpretable space (e.g. belonging to the same super-pixel) are all assigned the
same coefficient value.
"""

# check if inputs are tabular, time series, or images with shape (N, H, W, C)
is_tabular = len(inputs.shape) == 2
is_time_series = len(inputs.shape) == 3
has_channels = len(inputs.shape) == 4 and inputs.shape[-1] == 3

if has_channels:
# default quickshift segmentation for image
if self.map_to_interpret_space is None:
self.map_to_interpret_space = Lime._default_image_map_to_interpret_space
# if inputs have channels, ensure the reference value matches them
if self.ref_value is None:
if inputs.shape[-1] == 3:
# grey pixel
ref_value = tf.ones(inputs.shape[-1])*0.5
else:
ref_value = tf.zeros(inputs.shape[-1])
else:
assert(
self.ref_value.shape[0] == inputs.shape[-1]
),"The dimension of ref_values must match inputs (C, )"
ref_value = tf.cast(self.ref_value, tf.float32)
if self.ref_value is not None:
if len(inputs.shape) == 4:
assert(self.ref_value.shape[0] == inputs.shape[-1]),\
"The dimension of ref_values must match inputs (C, )"
self.ref_value = tf.cast(self.ref_value, tf.float32)
else:
if self.map_to_interpret_space is None:
if is_tabular:
self.map_to_interpret_space = Lime._default_tab_map_to_interpret_space
elif is_time_series:
self.map_to_interpret_space = Lime._default_time_series_map_to_interpret_space
else:
if len(inputs.shape) in [2, 3]: # Tabular data or time series
self.ref_value = tf.zeros(1)
elif len(inputs.shape) == 4: # Image
if inputs.shape[-1] == 3: # RGB image
# grey pixel
self.ref_value = tf.fill(inputs.shape[-1], 0.5)
elif inputs.shape[-1] == 1: # Black and white image
self.ref_value = tf.zeros(inputs.shape[-1])

if self.map_to_interpret_space is None:
if len(inputs.shape) == 2: # Tabular data
self.map_to_interpret_space = Lime._default_tab_map_to_interpret_space
elif len(inputs.shape) == 3: # Time series
self.map_to_interpret_space = Lime._default_time_series_map_to_interpret_space
elif len(inputs.shape) == 4: # Image
if inputs.shape[-1] == 3: # RGB image
self.map_to_interpret_space = Lime._default_image_map_to_interpret_space
elif inputs.shape[-1] == 1: # Black and white image
self.map_to_interpret_space = Lime._default_2dimage_map_to_interpret_space

if self.ref_value is None:
ref_value = tf.zeros(1)
else:
ref_value = tf.cast(self.ref_value, tf.float32)

batch_size = self.batch_size or self.nb_samples

return Lime._compute(self.model,
batch_size,
inputs,
targets,
self.inference_function,
self.interpretable_model,
self.similarity_kernel,
self.pertub_func,
ref_value,
self.map_to_interpret_space,
self.nb_samples,
)
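
For readability, the snippet below restates the `ref_value` defaults that the new
`_set_shape_dependant_parameters` resolves from the input shape; it is a standalone
sketch distilled from the hunks above, not code taken verbatim from the commit.

import tensorflow as tf

def default_ref_value(inputs: tf.Tensor) -> tf.Tensor:
    # Tabular (N, W) and time-series (N, T, W) inputs fall back to a zero baseline.
    if len(inputs.shape) in (2, 3):
        return tf.zeros(1)
    # RGB images (N, H, W, 3) default to a grey pixel, greyscale images to zero.
    if inputs.shape[-1] == 3:
        return tf.ones(inputs.shape[-1]) * 0.5
    return tf.zeros(inputs.shape[-1])

print(default_ref_value(tf.zeros((8, 10))))         # tabular   -> [0.]
print(default_ref_value(tf.zeros((8, 32, 32, 3))))  # RGB image -> [0.5 0.5 0.5]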

@staticmethod
def _compute(model: Callable,
batch_size: int,
inputs: tf.Tensor,
targets: tf.Tensor,
inference_function: Callable,
interpretable_model: Callable,
similarity_kernel: Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor],
pertub_func: Callable[[Union[int, tf.Tensor],int], tf.Tensor],
ref_value: tf.Tensor,
map_to_interpret_space: Callable[[tf.Tensor], tf.Tensor],
nb_samples: int,
) -> tf.Tensor:
# pylint: disable=R0913
@sanitize_input_output
def explain(self,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> tf.Tensor:
"""
This method attributes the output of the model with given targets
to the inputs of the model using the approach described above,
@@ -248,50 +199,33 @@ def _compute(model: Callable,
Parameters
----------
model
The model from which we want to obtain explanations
inputs
Dataset, Tensor or Array. Input samples to be explained.
If Dataset, targets should not be provided (included in Dataset).
Expected shape among (N, W), (N, T, W), (N, H, W, C).
More information in the documentation.
targets
Tensor or Array. One-hot encoding of the model's output from which an explanation
is desired. One encoding per input and only one output at a time.
is desired. One encoding per input and only one output at a time. Therefore,
the expected shape is (N, output_size).
More information in the documentation.
inference_function
Function used to get the probability output of the model.
interpretable_model
Model object used to train the interpretable model.
similarity_kernel
Function which, given an input, perturbed instances of that input and the
interpretable version of those perturbed samples, computes the similarities.
pertub_func
Function which generates perturbed interpretable samples in the interpretation space.
ref_value
Reference value which replaces each feature when the corresponding
interpretable feature is set to 0.
map_to_interpret_space
Function which groups features of an input corresponding to the same interpretable
feature (e.g. super-pixel).
nb_samples
The number of perturbed samples you want to generate for each input sample.
Returns
-------
explanations
A Tensor
Interpretable coefficients, same shape as the inputs, except for the channels.
Coefficients of the interpretable model: features that were grouped together in the
interpretable space (e.g. belonging to the same super-pixel) are all assigned the
same coefficient value.
"""
self._set_shape_dependant_parameters(inputs)
explanations = []

for inp, target in tf.data.Dataset.from_tensor_slices(
(inputs, targets)
):
# get the mapping of the current input
mapping = map_to_interpret_space(inp)
mapping = self.map_to_interpret_space(inp)
# get the number of interpretable features
num_features = tf.reduce_max(mapping) + tf.ones(1, dtype=tf.int32)
if tf.greater(num_features, 10000):
@@ -302,35 +236,35 @@
)

# get perturbed interpretable samples of the input
interpret_samples = pertub_func(num_features, nb_samples)
interpret_samples = self.pertub_func(num_features, self.nb_samples)

# get the perturbed targets value and the similarities value
perturbed_targets = []
similarities = []
for int_samples in tf.data.Dataset.from_tensor_slices(
interpret_samples
).batch(batch_size):
).batch(self.batch_size):

masks = Lime._get_masks(int_samples, mapping)
perturbed_samples = Lime._apply_masks(inp, masks, ref_value)
perturbed_samples = Lime._apply_masks(inp, masks, self.ref_value)

augmented_target = tf.expand_dims(target, axis=0)
augmented_target = tf.repeat(augmented_target, len(perturbed_samples), axis=0)

batch_perturbed_targets = inference_function(model,
perturbed_samples,
augmented_target)
batch_perturbed_targets = self.inference_function(self.model,
perturbed_samples,
augmented_target)

perturbed_targets.append(batch_perturbed_targets)

batch_similarities = similarity_kernel(inp, int_samples, perturbed_samples)
batch_similarities = self.similarity_kernel(inp, int_samples, perturbed_samples)
similarities.append(batch_similarities)

perturbed_targets = tf.concat(perturbed_targets, axis=0)
similarities = tf.concat(similarities, axis=0)

# train the interpretable model
explain_model = interpretable_model
explain_model = self.interpretable_model

explain_model.fit(
interpret_samples.numpy(),
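
For context, here is a minimal usage sketch of the refactored explainer. The import
path matches the file shown above (xplique/attributions/lime.py), but the constructor
keyword arguments and the stand-in model are assumptions for illustration, not part
of the commit.

import numpy as np
import tensorflow as tf
from xplique.attributions import Lime  # assumed public import path

# Any Keras classifier works; MobileNetV2 is only a stand-in here.
model = tf.keras.applications.MobileNetV2(weights=None)

# Four RGB images, shape (N, H, W, C), and their one-hot targets, shape (N, output_size).
images = np.random.rand(4, 224, 224, 3).astype(np.float32)
targets = tf.one_hot([1, 7, 42, 3], depth=1000)

# With no extra kwargs, `ref_value` and `map_to_interpret_space` are resolved from the
# input shape by `_set_shape_dependant_parameters` (grey-pixel baseline and quickshift
# segmentation for RGB images, per the hunks above).
explainer = Lime(model, nb_samples=150)  # `nb_samples` assumed to be a constructor kwarg
explanations = explainer.explain(images, targets)  # coefficients, same spatial shape as inputs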
