From 018c40421f7f810e5d6b4724b001b0b1e20f7c66 Mon Sep 17 00:00:00 2001
From: Eugene Liu <eugene.liu@intel.com>
Date: Wed, 22 Jan 2025 14:29:04 +0000
Subject: [PATCH] Implement explainability features in DFine and RTDETR
 models, enhancing output with raw logits and saliency maps for better
 interpretability.

---
 src/otx/algo/detection/d_fine.py               | 62 ++++++++++++++++++++
 src/otx/algo/detection/heads/dfine_decoder.py  | 31 ++++++++--
 .../algo/detection/heads/rtdetr_decoder.py     | 33 +++++++---
 src/otx/algo/detection/rtdetr.py               | 34 +++++-----
 4 files changed, 128 insertions(+), 32 deletions(-)

diff --git a/src/otx/algo/detection/d_fine.py b/src/otx/algo/detection/d_fine.py
index 5e16aa9c3c..1ff1d6bb53 100644
--- a/src/otx/algo/detection/d_fine.py
+++ b/src/otx/algo/detection/d_fine.py
@@ -157,6 +157,9 @@ def _customize_inputs(
             )
             targets.append({"boxes": scaled_bboxes, "labels": ll})
 
+        if self.explain_mode:
+            return {"entity": entity}
+
         return {
             "images": entity.images,
             "targets": targets,
@@ -185,6 +188,33 @@ def _customize_outputs(
         original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
         scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)
 
+        if self.explain_mode:
+            if not isinstance(outputs, dict):
+                msg = f"Model output should be a dict, but got {type(outputs)}."
+                raise ValueError(msg)
+
+            if "feature_vector" not in outputs:
+                msg = "No feature vector in the model output."
+                raise ValueError(msg)
+
+            if "saliency_map" not in outputs:
+                msg = "No saliency maps in the model output."
+                raise ValueError(msg)
+
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
+
+            return DetBatchPredEntity(
+                batch_size=len(outputs),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=scores,
+                bboxes=bboxes,
+                labels=labels,
+                feature_vector=feature_vector,
+                saliency_map=saliency_map,
+            )
+
         return DetBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
@@ -306,3 +336,35 @@ def _optimization_config(self) -> dict[str, Any]:
             },
         },
     }
+
+    @staticmethod
+    def _forward_explain_detection(
+        self,  # noqa: ANN001
+        entity: DetBatchDataEntity,
+        mode: str = "tensor",  # noqa: ARG004
+    ) -> dict[str, torch.Tensor]:
+        """Forward function for explainable detection model."""
+        backbone_feats = self.encoder(self.backbone(entity.images))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
+
+        feature_vector = self.feature_vector_fn(backbone_feats)
+
+        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
+
+        # Permute to (batch, classes, anchors) and split into per-level chunks
+        raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
+
+        # Fold each per-level chunk back onto its spatial (H, W) grid
+        raw_logits = [
+            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
+        ]
+
+        saliency_map = self.explain_fn(raw_logits)
+        predictions.update(
+            {
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
+            },
+        )
+
+        return predictions
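
Note: the permute/split/reshape sequence in `_forward_explain_detection` above turns the decoder's flattened per-anchor logits of shape (batch, num_anchors, num_classes) into one (batch, num_classes, H, W) map per pyramid level, which is the layout a spatial saliency algorithm consumes. The following standalone sketch (dummy shapes only, not part of the patch) sanity-checks that round-trip:

    import torch

    batch, num_classes = 2, 80
    # Three pyramid levels, as a typical 640x640 input would produce.
    feats = [torch.randn(batch, 256, s, s) for s in (80, 40, 20)]
    splits = [f.shape[-2] * f.shape[-1] for f in feats]        # [6400, 1600, 400]
    raw_logits = torch.randn(batch, sum(splits), num_classes)  # (B, anchors, classes)

    # Same ops as the patch: (B, anchors, classes) -> (B, classes, anchors),
    # then split per level and fold each chunk back onto its HxW grid.
    chunks = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)
    maps = [c.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for c, f in zip(chunks, feats)]
    assert [tuple(m.shape) for m in maps] == [(2, 80, 80, 80), (2, 80, 40, 40), (2, 80, 20, 20)]
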
diff --git a/src/otx/algo/detection/heads/dfine_decoder.py b/src/otx/algo/detection/heads/dfine_decoder.py
index d28e0cf386..e2d8f9dd66 100644
--- a/src/otx/algo/detection/heads/dfine_decoder.py
+++ b/src/otx/algo/detection/heads/dfine_decoder.py
@@ -723,7 +723,7 @@ def _get_decoder_input(
             enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
             content = torch.concat([denoising_logits, content], dim=1)
 
-        return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
+        return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits
 
     def _select_topk(
         self,
@@ -762,8 +762,22 @@ def _select_topk(
 
         return topk_memory, topk_logits, topk_anchors
 
-    def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]:
-        """Forward pass of the DFine Transformer module."""
+    def forward(
+        self,
+        feats: Tensor,
+        targets: list[dict[str, Tensor]] | None = None,
+        explain_mode: bool = False,
+    ) -> dict[str, Tensor]:
+        """Forward function of the D-FINE decoder transformer module.
+
+        Args:
+            feats (Tensor): Feature maps.
+            targets (list[dict[str, Tensor]] | None, optional): Target annotations. Defaults to None.
+            explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
+
+        Returns:
+            dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
+        """
         # input projection and embedding
         memory, spatial_shapes = self._get_encoder_input(feats)
 
@@ -781,7 +795,13 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
         else:
             denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
-        init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input(
+        (
+            init_ref_contents,
+            init_ref_points_unact,
+            enc_topk_bboxes_list,
+            enc_topk_logits_list,
+            raw_logits,
+        ) = self._get_decoder_input(
             memory,
             spatial_shapes,
             denoising_logits,
@@ -858,6 +878,9 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
             "pred_boxes": out_bboxes[-1],
         }
 
+        if explain_mode:
+            out["raw_logits"] = raw_logits
+
         return out
 
     @torch.jit.unused
diff --git a/src/otx/algo/detection/heads/rtdetr_decoder.py b/src/otx/algo/detection/heads/rtdetr_decoder.py
index dd5cf2f199..485c712cae 100644
--- a/src/otx/algo/detection/heads/rtdetr_decoder.py
+++ b/src/otx/algo/detection/heads/rtdetr_decoder.py
@@ -546,10 +546,10 @@ def _get_decoder_input(
 
         output_memory = self.enc_output(memory)
 
-        enc_outputs_class = self.enc_score_head(output_memory)
+        enc_outputs_logits = self.enc_score_head(output_memory)
         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
 
-        _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
+        _, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1)
 
         reference_points_unact = enc_outputs_coord_unact.gather(
             dim=1,
@@ -560,9 +560,9 @@ def _get_decoder_input(
         if denoising_bbox_unact is not None:
             reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
 
-        enc_topk_logits = enc_outputs_class.gather(
+        enc_topk_logits = enc_outputs_logits.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
         )
 
         # extract region features
@@ -575,10 +575,24 @@ def _get_decoder_input(
         if denoising_class is not None:
             target = torch.concat([denoising_class, target], 1)
 
-        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
+        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits
 
-    def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor:
-        """Forward pass of the RTDETRTransformer module."""
+    def forward(
+        self,
+        feats: torch.Tensor,
+        targets: list[dict[str, torch.Tensor]] | None = None,
+        explain_mode: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        """Forward function of RTDETRTransformer.
+
+        Args:
+            feats (Tensor): Input features.
+            targets (list[dict[str, Tensor]] | None, optional): List of target dictionaries. Defaults to None.
+            explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
+
+        Returns:
+            dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
+        """
         # input projection and embedding
         (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
 
@@ -596,7 +610,7 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
         else:
             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
-        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input(
             memory,
             spatial_shapes,
             denoising_class,
@@ -630,6 +644,9 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
             out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
             out["dn_meta"] = dn_meta
 
+        if explain_mode:
+            out["raw_logits"] = raw_logits
+
         return out
 
     @torch.jit.unused
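
Note: both decoder forwards follow the same contract: a call without `explain_mode` is unchanged, while `explain_mode=True` additionally exposes the encoder's pre-top-k class logits under the "raw_logits" key. The toy module below (a stand-in, not the real `RTDETRTransformer` or D-FINE decoder) isolates just that flag plumbing:

    import torch
    from torch import nn

    class ToyDecoder(nn.Module):
        """Toy stand-in for the real decoders; only the explain_mode plumbing is real."""

        def __init__(self, dim: int = 16, num_classes: int = 80, num_queries: int = 300) -> None:
            super().__init__()
            self.enc_score_head = nn.Linear(dim, num_classes)
            self.num_queries = num_queries

        def forward(self, memory: torch.Tensor, explain_mode: bool = False) -> dict[str, torch.Tensor]:
            logits = self.enc_score_head(memory)  # (batch, num_anchors, num_classes)
            # The real models keep only the top-k queries for prediction; plain
            # slicing stands in for topk/gather here.
            out = {"pred_logits": logits[:, : self.num_queries]}
            if explain_mode:
                out["raw_logits"] = logits  # every anchor, later folded into saliency maps
            return out

    decoder = ToyDecoder()
    out = decoder(torch.randn(2, 8400, 16), explain_mode=True)
    print(out["raw_logits"].shape)  # torch.Size([2, 8400, 80])
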
diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py
index b89f68915c..e4eb1abd61 100644
--- a/src/otx/algo/detection/rtdetr.py
+++ b/src/otx/algo/detection/rtdetr.py
@@ -7,7 +7,7 @@
 
 import copy
 import re
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 import torch
 from torch import Tensor, nn
@@ -302,23 +302,6 @@ def _optimization_config(self) -> dict[str, Any]:
         """PTQ config for RT-DETR."""
         return {"model_type": "transformer"}
 
-    @torch.no_grad()
-    def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward function for the score head of the model."""
-        x = x.squeeze()
-        return self.model.decoder.dec_score_head[-1](x)
-
-    def get_explain_fn(self) -> Callable:
-        """Returns explain function."""
-        from otx.algo.explain.explain_algo import ReciproCAM
-
-        explainer = ReciproCAM(
-            head_forward_fn=self.head_forward_fn,
-            num_classes=self.num_classes,
-            optimize_gap=True,
-        )
-        return explainer.func
-
     @staticmethod
     def _forward_explain_detection(
         self,  # noqa: ANN001
@@ -327,10 +310,21 @@ def _forward_explain_detection(
     ) -> dict[str, torch.Tensor]:
         """Forward function for explainable detection model."""
         backbone_feats = self.encoder(self.backbone(entity.images))
-        predictions = self.decoder(backbone_feats)
+        predictions = self.decoder(backbone_feats, explain_mode=True)
 
         feature_vector = self.feature_vector_fn(backbone_feats)
-        saliency_map = self.explain_fn(backbone_feats[-1])
+
+        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
+
+        # Permute to (batch, classes, anchors) and split into per-level chunks
+        raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
+
+        # Fold each per-level chunk back onto its spatial (H, W) grid
+        raw_logits = [
+            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
+        ]
+
+        saliency_map = self.explain_fn(raw_logits)
         predictions.update(
             {
                 "feature_vector": feature_vector,
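
Note: end to end, `explain_mode` short-circuits `_customize_inputs` to hand the raw entity to `_forward_explain_detection`, which attaches `feature_vector` and `saliency_map` to the decoder output; `_customize_outputs` then surfaces both on the returned `DetBatchPredEntity`. A hedged usage sketch follows (the model instance and the `forward_explain` entry point come from the surrounding OTX base classes, not from this patch):

    # `model` is assumed to be an OTX detection model (e.g. DFine) and
    # `batch` a DetBatchDataEntity; both come from outside this patch.
    model.explain_mode = True             # _customize_inputs now returns {"entity": entity}
    preds = model.forward_explain(batch)  # dispatches to _forward_explain_detection
    print(preds.saliency_map.shape)       # per-class saliency maps as a numpy array
    print(preds.feature_vector.shape)     # pooled backbone descriptor per image
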