diff --git a/xplique/features_visualizations/objectives.py b/xplique/features_visualizations/objectives.py index 3e880e72..5e663536 100644 --- a/xplique/features_visualizations/objectives.py +++ b/xplique/features_visualizations/objectives.py @@ -121,9 +121,10 @@ def compile(self) -> Tuple[tf.keras.Model, Callable, List[str], Tuple]: def objective_function(model_outputs): loss = 0.0 for output_index in range(0, nb_sub_objectives): - loss += self.funcs[output_index](model_outputs[output_index], - masks[output_index]) * \ - multipliers[output_index] + outputs = model_outputs[output_index] + loss += self.funcs[output_index]( + outputs, tf.cast(masks[output_index], outputs.dtype)) + loss *= multipliers[output_index] return loss # the model outputs will be composed of the layers needed