From 77e59abd112571cb79262c3f1e6877960d50e805 Mon Sep 17 00:00:00 2001 From: Agustin Picard Date: Mon, 9 Sep 2024 16:00:33 +0200 Subject: [PATCH] lint: clean-up opti-cam --- xplique/attributions/opti_cam.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/xplique/attributions/opti_cam.py b/xplique/attributions/opti_cam.py index a8cb5452..c8fe8262 100644 --- a/xplique/attributions/opti_cam.py +++ b/xplique/attributions/opti_cam.py @@ -1,3 +1,7 @@ +""" +Module related to the Opti-CAM method +https://www.sciencedirect.com/science/article/pii/S1077314224001826?via%3Dihub +""" import tensorflow as tf import numpy as np @@ -8,7 +12,7 @@ _normalization_dict = { 'max_min': lambda x: (x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)), - 'sigmoid': lambda x: tf.nn.sigmoid(x), + 'sigmoid': tf.nn.sigmoid, 'max': lambda x: x / tf.reduce_max(x), } @@ -118,8 +122,9 @@ def explain(self, batch_size): # initialize weights and optimization elements current_batch_size = x_batch.shape[0] - weights = tf.Variable(0.5 * tf.ones((current_batch_size, 1, 1, self.conv_layer.output.shape[-1])), - trainable=True, dtype=tf.float32) + weights = tf.Variable( + 0.5 * tf.ones((current_batch_size, 1, 1, self.conv_layer.output.shape[-1]) + ), trainable=True, dtype=tf.float32) optimizer = tf.keras.optimizers.Adam(0.05) for _ in range(self.n_iters): with tf.GradientTape(watch_accessed_variables=False) as tape: @@ -127,9 +132,13 @@ def explain(self, feature_maps, logits = self.model(x_batch) explanations = self._one_step_explanation(x_batch, feature_maps, weights) explanations = tf.map_fn(self.normalize, explanations) - x_perturbed = tf.multiply(tf.tile(explanations, [1, 1, 1, x_batch.shape[-1]]), x_batch) + x_perturbed = tf.multiply( + tf.tile(explanations, [1, 1, 1, x_batch.shape[-1]]), + x_batch + ) _, logits_perturbed = self.model(x_perturbed) - logits, logits_perturbed = tf.reduce_sum(logits * y_batch, axis=1), tf.reduce_sum(logits_perturbed * y_batch, axis=1) + logits = tf.reduce_sum(logits * y_batch, axis=1) + logits_perturbed = tf.reduce_sum(logits_perturbed * y_batch, axis=1) score = self.loss_fn(logits, logits_perturbed) grads = tape.gradient(score, weights) optimizer.apply_gradients([(grads, weights)]) @@ -137,7 +146,8 @@ def explain(self, # Generate the final explanations for the batch of images explanations = self._one_step_explanation(x_batch, feature_maps, weights) explanations = tf.map_fn(self.normalize, explanations) - opti_cams = explanations if opti_cams is None else tf.concat([opti_cams, explanations], axis=0) + opti_cams = explanations if opti_cams is None \ + else tf.concat([opti_cams, explanations], axis=0) return opti_cams