From 018c40421f7f810e5d6b4724b001b0b1e20f7c66 Mon Sep 17 00:00:00 2001
From: Eugene Liu <eugene.liu@intel.com>
Date: Wed, 22 Jan 2025 14:29:04 +0000
Subject: [PATCH] Implement explainability features in DFine and RTDETR models,
 extending the model outputs with feature vectors and saliency maps derived
 from the raw decoder logits for better interpretability.

---
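Notes:

The decoder now optionally returns its pre-top-k encoder logits under the
"raw_logits" key, and the explain forward hook reshapes them back into
per-level spatial maps before handing them to the saliency function. Below is
a minimal standalone sketch of that reshaping step, assuming backbone_feats is
a list of (B, C, H_i, W_i) feature maps and raw_logits has shape
(B, sum(H_i * W_i), num_classes) as in this patch; the helper name
logits_to_spatial_maps is illustrative only and is not part of the change:

    import torch

    def logits_to_spatial_maps(
        raw_logits: torch.Tensor,
        backbone_feats: list[torch.Tensor],
    ) -> list[torch.Tensor]:
        # One chunk per pyramid level, sized by that level's spatial extent.
        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
        # (B, HW_total, num_classes) -> (B, num_classes, HW_total), then split per level.
        chunks = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)
        # Restore each chunk to the (B, num_classes, H_i, W_i) layout of its level.
        return [
            chunk.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1])
            for chunk, f in zip(chunks, backbone_feats)
        ]

The resulting per-level maps feed self.explain_fn to build the saliency map,
while self.feature_vector_fn pools the backbone features into the feature
vector; both are attached to the prediction dict and surfaced through
DetBatchPredEntity when explain_mode is enabled.
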
 src/otx/algo/detection/d_fine.py              | 62 +++++++++++++++++++
 src/otx/algo/detection/heads/dfine_decoder.py | 31 ++++++++--
 .../algo/detection/heads/rtdetr_decoder.py    | 33 +++++++---
 src/otx/algo/detection/rtdetr.py              | 34 +++++-----
 4 files changed, 128 insertions(+), 32 deletions(-)

diff --git a/src/otx/algo/detection/d_fine.py b/src/otx/algo/detection/d_fine.py
index 5e16aa9c3c..1ff1d6bb53 100644
--- a/src/otx/algo/detection/d_fine.py
+++ b/src/otx/algo/detection/d_fine.py
@@ -157,6 +157,9 @@ def _customize_inputs(
                 )
             targets.append({"boxes": scaled_bboxes, "labels": ll})
 
+        if self.explain_mode:
+            return {"entity": entity}
+
         return {
             "images": entity.images,
             "targets": targets,
@@ -185,6 +188,33 @@ def _customize_outputs(
         original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
         scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)
 
+        if self.explain_mode:
+            if not isinstance(outputs, dict):
+                msg = f"Model output should be a dict, but got {type(outputs)}."
+                raise ValueError(msg)
+
+            if "feature_vector" not in outputs:
+                msg = "No feature vector in the model output."
+                raise ValueError(msg)
+
+            if "saliency_map" not in outputs:
+                msg = "No saliency maps in the model output."
+                raise ValueError(msg)
+
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
+
+            return DetBatchPredEntity(
+                batch_size=len(outputs),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=scores,
+                bboxes=bboxes,
+                labels=labels,
+                feature_vector=feature_vector,
+                saliency_map=saliency_map,
+            )
+
         return DetBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
@@ -306,3 +336,35 @@ def _optimization_config(self) -> dict[str, Any]:
                 },
             },
         }
+
+    @staticmethod
+    def _forward_explain_detection(
+        self,  # noqa: ANN001
+        entity: DetBatchDataEntity,
+        mode: str = "tensor",  # noqa: ARG004
+    ) -> dict[str, torch.Tensor]:
+        """Forward function for explainable detection model."""
+        backbone_feats = self.encoder(self.backbone(entity.images))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
+
+        feature_vector = self.feature_vector_fn(backbone_feats)
+
+        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
+
+        # Permute to (B, num_classes, HW_total) and split into one chunk per pyramid level
+        raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
+
+        # Reshape each chunk back to the (B, num_classes, H, W) layout of its feature map
+        raw_logits = [
+            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
+        ]
+
+        saliency_map = self.explain_fn(raw_logits)
+        predictions.update(
+            {
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
+            },
+        )
+
+        return predictions
diff --git a/src/otx/algo/detection/heads/dfine_decoder.py b/src/otx/algo/detection/heads/dfine_decoder.py
index d28e0cf386..e2d8f9dd66 100644
--- a/src/otx/algo/detection/heads/dfine_decoder.py
+++ b/src/otx/algo/detection/heads/dfine_decoder.py
@@ -723,7 +723,7 @@ def _get_decoder_input(
             enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
             content = torch.concat([denoising_logits, content], dim=1)
 
-        return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
+        return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits
 
     def _select_topk(
         self,
@@ -762,8 +762,22 @@ def _select_topk(
 
         return topk_memory, topk_logits, topk_anchors
 
-    def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]:
-        """Forward pass of the DFine Transformer module."""
+    def forward(
+        self,
+        feats: Tensor,
+        targets: list[dict[str, Tensor]] | None = None,
+        explain_mode: bool = False,
+    ) -> dict[str, Tensor]:
+        """Forward function of the D-FINE Decoder Transformer Module.
+
+        Args:
+            feats (Tensor): Feature maps.
+            targets (list[dict[str, Tensor]] | None, optional): target annotations. Defaults to None.
+            explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
+
+        Returns:
+            dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
+        """
         # input projection and embedding
         memory, spatial_shapes = self._get_encoder_input(feats)
 
@@ -781,7 +795,13 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
         else:
             denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
-        init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input(
+        (
+            init_ref_contents,
+            init_ref_points_unact,
+            enc_topk_bboxes_list,
+            enc_topk_logits_list,
+            raw_logits,
+        ) = self._get_decoder_input(
             memory,
             spatial_shapes,
             denoising_logits,
@@ -858,6 +878,9 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
                 "pred_boxes": out_bboxes[-1],
             }
 
+        if explain_mode:
+            out["raw_logits"] = raw_logits
+
         return out
 
     @torch.jit.unused
diff --git a/src/otx/algo/detection/heads/rtdetr_decoder.py b/src/otx/algo/detection/heads/rtdetr_decoder.py
index dd5cf2f199..485c712cae 100644
--- a/src/otx/algo/detection/heads/rtdetr_decoder.py
+++ b/src/otx/algo/detection/heads/rtdetr_decoder.py
@@ -546,10 +546,10 @@ def _get_decoder_input(
 
         output_memory = self.enc_output(memory)
 
-        enc_outputs_class = self.enc_score_head(output_memory)
+        enc_outputs_logits = self.enc_score_head(output_memory)
         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
 
-        _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
+        _, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1)
 
         reference_points_unact = enc_outputs_coord_unact.gather(
             dim=1,
@@ -560,9 +560,9 @@ def _get_decoder_input(
         if denoising_bbox_unact is not None:
             reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
 
-        enc_topk_logits = enc_outputs_class.gather(
+        enc_topk_logits = enc_outputs_logits.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
         )
 
         # extract region features
@@ -575,10 +575,24 @@ def _get_decoder_input(
         if denoising_class is not None:
             target = torch.concat([denoising_class, target], 1)
 
-        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
+        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits
 
-    def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor:
-        """Forward pass of the RTDETRTransformer module."""
+    def forward(
+        self,
+        feats: torch.Tensor,
+        targets: list[dict[str, torch.Tensor]] | None = None,
+        explain_mode: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        """Forward function of RTDETRTransformer.
+
+        Args:
+            feats (Tensor): Input feature maps.
+            targets (list[dict[str, Tensor]] | None, optional): List of target dictionaries. Defaults to None.
+            explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
+
+        Returns:
+            dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
+        """
         # input projection and embedding
         (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
 
@@ -596,7 +610,7 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
         else:
             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
-        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input(
             memory,
             spatial_shapes,
             denoising_class,
@@ -630,6 +644,9 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
                 out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
                 out["dn_meta"] = dn_meta
 
+        if explain_mode:
+            out["raw_logits"] = raw_logits
+
         return out
 
     @torch.jit.unused
diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py
index b89f68915c..e4eb1abd61 100644
--- a/src/otx/algo/detection/rtdetr.py
+++ b/src/otx/algo/detection/rtdetr.py
@@ -7,7 +7,7 @@
 
 import copy
 import re
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 import torch
 from torch import Tensor, nn
@@ -302,23 +302,6 @@ def _optimization_config(self) -> dict[str, Any]:
         """PTQ config for RT-DETR."""
         return {"model_type": "transformer"}
 
-    @torch.no_grad()
-    def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward function for the score head of the model."""
-        x = x.squeeze()
-        return self.model.decoder.dec_score_head[-1](x)
-
-    def get_explain_fn(self) -> Callable:
-        """Returns explain function."""
-        from otx.algo.explain.explain_algo import ReciproCAM
-
-        explainer = ReciproCAM(
-            head_forward_fn=self.head_forward_fn,
-            num_classes=self.num_classes,
-            optimize_gap=True,
-        )
-        return explainer.func
-
     @staticmethod
     def _forward_explain_detection(
         self,  # noqa: ANN001
@@ -327,10 +310,21 @@ def _forward_explain_detection(
     ) -> dict[str, torch.Tensor]:
         """Forward function for explainable detection model."""
         backbone_feats = self.encoder(self.backbone(entity.images))
-        predictions = self.decoder(backbone_feats)
+        predictions = self.decoder(backbone_feats, explain_mode=True)
 
         feature_vector = self.feature_vector_fn(backbone_feats)
-        saliency_map = self.explain_fn(backbone_feats[-1])
+
+        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
+
+        # Permute to (B, num_classes, HW_total) and split into one chunk per pyramid level
+        raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
+
+        # Reshape each chunk back to the (B, num_classes, H, W) layout of its feature map
+        raw_logits = [
+            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
+        ]
+
+        saliency_map = self.explain_fn(raw_logits)
         predictions.update(
             {
                 "feature_vector": feature_vector,