diff --git a/tests/attributions/test_common.py b/tests/attributions/test_common.py
index 83f25304..9add50c4 100644
--- a/tests/attributions/test_common.py
+++ b/tests/attributions/test_common.py
@@ -57,6 +57,9 @@ def test_common():
         # all explanations returned must be either a tf.Tensor or ndarray
         assert isinstance(explanations, (tf.Tensor, np.ndarray))
 
+        # we should have one explanation for each input
+        assert len(explanations) == len(inputs_np)
+
 
 def test_batch_size():
     """Ensure the functioning of attributions for special batch size cases"""
@@ -89,12 +92,13 @@ def test_batch_size():
         ]
 
         for method in methods:
-            try:
-                explanations = method.explain(inputs, targets)
-            except:
-                raise AssertionError(
-                    "Explanation failed for method ", method.__class__.__name__,
-                    " batch size ", bs)
+            explanations = method.explain(inputs, targets)
+
+            # all explanations returned must be either a tf.Tensor or ndarray
+            assert isinstance(explanations, (tf.Tensor, np.ndarray))
+
+            # we should have one explanation for each input
+            assert len(explanations) == len(inputs)
 
 
 def test_model_caching():
@@ -118,4 +122,4 @@ def test_model_caching():
 
     # ensure that there no more than one key has been added
     assert (len(
-        BlackBoxExplainer._cache_models) == cache_len_before + 1)  # pylint: disable=protected-access
+        BlackBoxExplainer._cache_models) == cache_len_before + 1)  # pylint: disable=protected-access
\ No newline at end of file
diff --git a/xplique/commons/data_conversion.py b/xplique/commons/data_conversion.py
index 9787451e..517f86ad 100644
--- a/xplique/commons/data_conversion.py
+++ b/xplique/commons/data_conversion.py
@@ -31,15 +31,13 @@ def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
 
     # deal with tf.data.Dataset
     if isinstance(inputs, tf.data.Dataset):
-        # if the dataset as 4 dimensions, assume it is batched
-        dataset_shape = inputs.element_spec[0].shape
-        if len(dataset_shape) == 4:
+        # check whether the dataset is batched; if so, unbatch it
+        if hasattr(inputs, '_batch_size'):
             inputs = inputs.unbatch()
 
         # unpack the dataset, assume we have tuple of (input, target)
-        targets = [target for inp, target in inputs]
-        inputs = [inp for inp, target in inputs]
+        targets = [target for _, target in inputs]
+        inputs = [inp for inp, _ in inputs]
 
-    # deal with numpy array
     inputs = tf.cast(inputs, tf.float32)
     targets = tf.cast(targets, tf.float32)
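
Note on the data_conversion.py change: a dataset produced by .batch() is a BatchDataset carrying a private `_batch_size` attribute, so probing for that attribute detects batching for elements of any rank, whereas the old `len(shape) == 4` heuristic only recognized image-shaped inputs. A minimal standalone sketch of the idea (the toy shapes and variable names below are illustrative, not part of the patch):

import numpy as np
import tensorflow as tf

# toy (input, target) pairs; the shapes are arbitrary
inputs = np.random.rand(8, 32, 32, 3).astype(np.float32)
targets = np.eye(10)[np.random.randint(0, 10, 8)].astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(4)

# BatchDataset exposes a private `_batch_size` attribute: probing for it
# detects batching without assuming anything about the element rank
if hasattr(dataset, '_batch_size'):
    dataset = dataset.unbatch()

# once unbatched, iterating yields one (input, target) pair at a time
xs = tf.stack([inp for inp, _ in dataset])
ys = tf.stack([target for _, target in dataset])
assert len(xs) == len(inputs) and len(ys) == len(targets)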
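
One caveat worth recording: `_batch_size` is a private TensorFlow attribute, and only the outermost transformation is inspected, so a pipeline such as .batch(4).prefetch(1) ends in a PrefetchDataset and would likely not be detected as batched by this check. As a best-effort sanitization heuristic that seems acceptable, but it is not a guarantee.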