Skip to content

Commit

Permalink
lint: clean-up opti-cam
Browse files Browse the repository at this point in the history
  • Loading branch information
Agustin-Picard committed Sep 9, 2024
1 parent cea7d21 commit 77e59ab
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions xplique/attributions/opti_cam.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Module related to the Opti-CAM method
https://www.sciencedirect.com/science/article/pii/S1077314224001826?via%3Dihub
"""
import tensorflow as tf
import numpy as np

Expand All @@ -8,7 +12,7 @@

_normalization_dict = {
'max_min': lambda x: (x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)),
'sigmoid': lambda x: tf.nn.sigmoid(x),
'sigmoid': tf.nn.sigmoid,
'max': lambda x: x / tf.reduce_max(x),
}

Expand Down Expand Up @@ -118,26 +122,32 @@ def explain(self,
batch_size):
# initialize weights and optimization elements
current_batch_size = x_batch.shape[0]
weights = tf.Variable(0.5 * tf.ones((current_batch_size, 1, 1, self.conv_layer.output.shape[-1])),
trainable=True, dtype=tf.float32)
weights = tf.Variable(
0.5 * tf.ones((current_batch_size, 1, 1, self.conv_layer.output.shape[-1])
), trainable=True, dtype=tf.float32)
optimizer = tf.keras.optimizers.Adam(0.05)
for _ in range(self.n_iters):
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(weights)
feature_maps, logits = self.model(x_batch)
explanations = self._one_step_explanation(x_batch, feature_maps, weights)
explanations = tf.map_fn(self.normalize, explanations)
x_perturbed = tf.multiply(tf.tile(explanations, [1, 1, 1, x_batch.shape[-1]]), x_batch)
x_perturbed = tf.multiply(
tf.tile(explanations, [1, 1, 1, x_batch.shape[-1]]),
x_batch
)
_, logits_perturbed = self.model(x_perturbed)
logits, logits_perturbed = tf.reduce_sum(logits * y_batch, axis=1), tf.reduce_sum(logits_perturbed * y_batch, axis=1)
logits = tf.reduce_sum(logits * y_batch, axis=1)
logits_perturbed = tf.reduce_sum(logits_perturbed * y_batch, axis=1)
score = self.loss_fn(logits, logits_perturbed)
grads = tape.gradient(score, weights)
optimizer.apply_gradients([(grads, weights)])

# Generate the final explanations for the batch of images
explanations = self._one_step_explanation(x_batch, feature_maps, weights)
explanations = tf.map_fn(self.normalize, explanations)
opti_cams = explanations if opti_cams is None else tf.concat([opti_cams, explanations], axis=0)
opti_cams = explanations if opti_cams is None \
else tf.concat([opti_cams, explanations], axis=0)

return opti_cams

Expand Down

0 comments on commit 77e59ab

Please sign in to comment.