From c3ff31bbb91c23ccf1c1e707ceeda7bda0ea5e95 Mon Sep 17 00:00:00 2001 From: Antonin POCHE Date: Thu, 5 Oct 2023 15:20:46 +0200 Subject: [PATCH] feature viz objectives: fix issue 131 --- xplique/features_visualizations/objectives.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xplique/features_visualizations/objectives.py b/xplique/features_visualizations/objectives.py index 3e880e72..5e663536 100644 --- a/xplique/features_visualizations/objectives.py +++ b/xplique/features_visualizations/objectives.py @@ -121,9 +121,10 @@ def compile(self) -> Tuple[tf.keras.Model, Callable, List[str], Tuple]: def objective_function(model_outputs): loss = 0.0 for output_index in range(0, nb_sub_objectives): - loss += self.funcs[output_index](model_outputs[output_index], - masks[output_index]) * \ - multipliers[output_index] + outputs = model_outputs[output_index] + loss += self.funcs[output_index]( + outputs, tf.cast(masks[output_index], outputs.dtype)) + loss *= multipliers[output_index] return loss # the model outputs will be composed of the layers needed