Implement explainability features in DFine and RTDETR models, enhancing output with raw logits and saliency maps for better interpretability.
eugene123tw committed Jan 22, 2025
1 parent 7ca1b7d commit 018c404
Showing 4 changed files with 128 additions and 32 deletions.
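Note: the diff threads a single `explain_mode` flag from the OTX model wrappers down into the decoder modules, which then expose their raw encoder logits so that per-level saliency maps can be computed. A minimal usage sketch follows; the constructor arguments and the `batch` entity are hypothetical, while the `explain_mode` attribute and the extra `DetBatchPredEntity` fields come from this commit:

```python
import torch
from otx.algo.detection.d_fine import DFine  # assumed import path for this file

# Hypothetical instantiation; argument names/values are illustrative only.
model = DFine(model_name="dfine_hgnetv2_x", label_info=3)
model.eval()
model.explain_mode = True  # the flag checked in _customize_inputs/_customize_outputs

with torch.no_grad():
    preds = model(batch)  # `batch` is a DetBatchDataEntity prepared elsewhere

# In explain mode the prediction entity also carries explanation artifacts
# (numpy arrays, per the .detach().cpu().numpy() calls below).
print(preds.saliency_map.shape, preds.feature_vector.shape)
```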
62 changes: 62 additions & 0 deletions src/otx/algo/detection/d_fine.py
@@ -157,6 +157,9 @@ def _customize_inputs(
)
targets.append({"boxes": scaled_bboxes, "labels": ll})

if self.explain_mode:
return {"entity": entity}

return {
"images": entity.images,
"targets": targets,
@@ -185,6 +188,33 @@ def _customize_outputs(
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

if self.explain_mode:
if not isinstance(outputs, dict):
msg = f"Model output should be a dict, but got {type(outputs)}."
raise ValueError(msg)

if "feature_vector" not in outputs:
msg = "No feature vector in the model output."
raise ValueError(msg)

if "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)

saliency_map = outputs["saliency_map"].detach().cpu().numpy()
feature_vector = outputs["feature_vector"].detach().cpu().numpy()

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
labels=labels,
feature_vector=feature_vector,
saliency_map=saliency_map,
)

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
@@ -306,3 +336,35 @@ def _optimization_config(self) -> dict[str, Any]:
},
},
}

@staticmethod
def _forward_explain_detection(
self, # noqa: ANN001
entity: DetBatchDataEntity,
mode: str = "tensor", # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward function for explainable detection model."""
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

feature_vector = self.feature_vector_fn(backbone_feats)

splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]

# Permute and split logits in one line
raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
raw_logits = [
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
]

saliency_map = self.explain_fn(raw_logits)
predictions.update(
{
"feature_vector": feature_vector,
"saliency_map": saliency_map,
},
)

return predictions
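The permute/split/reshape sequence in `_forward_explain_detection` converts the decoder's flattened per-anchor logits back into one class-logit map per pyramid level before handing them to `explain_fn`. A self-contained sketch of the same transformation with toy sizes (all shapes assumed for illustration):

```python
import torch

batch, num_classes = 2, 80
feat_shapes = [(64, 64), (32, 32), (16, 16)]  # assumed per-level H, W
num_anchors = sum(h * w for h, w in feat_shapes)

raw_logits = torch.randn(batch, num_anchors, num_classes)  # (B, HW_total, C)

# Move classes to the channel dim, then split the anchor dim by level size.
splits = [h * w for h, w in feat_shapes]
per_level = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)  # each (B, C, H*W)

# Restore each level's spatial layout.
per_level = [
    logits.reshape(batch, num_classes, h, w)
    for logits, (h, w) in zip(per_level, feat_shapes)
]
for level in per_level:
    print(level.shape)  # (2, 80, 64, 64), (2, 80, 32, 32), (2, 80, 16, 16)
```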
31 changes: 27 additions & 4 deletions src/otx/algo/detection/heads/dfine_decoder.py
@@ -723,7 +723,7 @@ def _get_decoder_input(
enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
content = torch.concat([denoising_logits, content], dim=1)

return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits

def _select_topk(
self,
@@ -762,8 +762,22 @@ def _select_topk(

return topk_memory, topk_logits, topk_anchors

def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]:
"""Forward pass of the DFine Transformer module."""
def forward(
self,
feats: Tensor,
targets: list[dict[str, Tensor]] | None = None,
explain_mode: bool = False,
) -> dict[str, Tensor]:
"""Forward function of the D-FINE Decoder Transformer Module.
Args:
feats (Tensor): Feature maps.
targets (list[dict[str, Tensor]] | None, optional): target annotations. Defaults to None.
explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
Returns:
dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
"""
# input projection and embedding
memory, spatial_shapes = self._get_encoder_input(feats)

@@ -781,7 +795,13 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
else:
denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input(
(
init_ref_contents,
init_ref_points_unact,
enc_topk_bboxes_list,
enc_topk_logits_list,
raw_logits,
) = self._get_decoder_input(
memory,
spatial_shapes,
denoising_logits,
@@ -858,6 +878,9 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
"pred_boxes": out_bboxes[-1],
}

if explain_mode:
out["raw_logits"] = raw_logits

return out

@torch.jit.unused
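With the extended signature, callers opt into the extra output explicitly, and `out["raw_logits"]` exists only when `explain_mode=True`. A hedged sketch of the calling convention (`decoder` stands in for an instantiated decoder from this file; construction is omitted):

```python
out = decoder(feats)  # training/inference path, unchanged
out_explain = decoder(feats, explain_mode=True)

assert "raw_logits" not in out
raw_logits = out_explain["raw_logits"]  # encoder logits, (B, num_anchors, num_classes)
```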
33 changes: 25 additions & 8 deletions src/otx/algo/detection/heads/rtdetr_decoder.py
@@ -546,10 +546,10 @@ def _get_decoder_input(

output_memory = self.enc_output(memory)

enc_outputs_class = self.enc_score_head(output_memory)
enc_outputs_logits = self.enc_score_head(output_memory)
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
_, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1)

reference_points_unact = enc_outputs_coord_unact.gather(
dim=1,
@@ -560,9 +560,9 @@
if denoising_bbox_unact is not None:
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)

enc_topk_logits = enc_outputs_class.gather(
enc_topk_logits = enc_outputs_logits.gather(
dim=1,
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]),
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
)

# extract region features
@@ -575,10 +575,24 @@
if denoising_class is not None:
target = torch.concat([denoising_class, target], 1)

return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits

def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor:
"""Forward pass of the RTDETRTransformer module."""
def forward(
self,
feats: torch.Tensor,
targets: list[dict[str, torch.Tensor]] | None = None,
explain_mode: bool = False,
) -> dict[str, torch.Tensor]:
"""Forward function of RTDETRTransformer.
Args:
feats (Tensor): Input features.
targets (List[Dict[str, Tensor]]): List of target dictionaries.
explain_mode (bool): Whether to return raw logits for explanation.
Returns:
dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
"""
# input projection and embedding
(memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)

@@ -596,7 +610,7 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
else:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input(
memory,
spatial_shapes,
denoising_class,
@@ -630,6 +644,9 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
out["dn_meta"] = dn_meta

if explain_mode:
out["raw_logits"] = raw_logits

return out

@torch.jit.unused
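The rename from `enc_outputs_class` to `enc_outputs_logits` does not change the top-k query selection around it; the same tensor is simply returned as `raw_logits` as well. A toy reproduction of that selection, with all sizes assumed:

```python
import torch

batch, num_anchors, num_classes, num_queries = 2, 100, 80, 10
enc_outputs_logits = torch.randn(batch, num_anchors, num_classes)

# Score each anchor by its best class logit; keep the top-k anchors per image.
_, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, num_queries, dim=1)

# Gather the full class-logit rows of the selected anchors.
enc_topk_logits = enc_outputs_logits.gather(
    dim=1,
    index=topk_ind.unsqueeze(-1).repeat(1, 1, num_classes),
)
print(enc_topk_logits.shape)  # torch.Size([2, 10, 80])
```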
34 changes: 14 additions & 20 deletions src/otx/algo/detection/rtdetr.py
@@ -7,7 +7,7 @@

import copy
import re
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Literal

import torch
from torch import Tensor, nn
@@ -302,23 +302,6 @@ def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for RT-DETR."""
return {"model_type": "transformer"}

@torch.no_grad()
def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function for the score head of the model."""
x = x.squeeze()
return self.model.decoder.dec_score_head[-1](x)

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.explain.explain_algo import ReciproCAM

explainer = ReciproCAM(
head_forward_fn=self.head_forward_fn,
num_classes=self.num_classes,
optimize_gap=True,
)
return explainer.func

@staticmethod
def _forward_explain_detection(
self, # noqa: ANN001
@@ -327,10 +310,21 @@ def _forward_explain_detection(
) -> dict[str, torch.Tensor]:
"""Forward function for explainable detection model."""
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats)
predictions = self.decoder(backbone_feats, explain_mode=True)

feature_vector = self.feature_vector_fn(backbone_feats)
saliency_map = self.explain_fn(backbone_feats[-1])

splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]

# Permute and split logits in one line
raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
raw_logits = [
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
]

saliency_map = self.explain_fn(raw_logits)
predictions.update(
{
"feature_vector": feature_vector,
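The ReciproCAM-based `head_forward_fn`/`get_explain_fn` pair is removed: saliency now comes straight from the decoder's raw logits, and only `feature_vector_fn` is still applied to the backbone features. Its definition lives elsewhere in OTX; a plausible sketch of such a function, assuming it pools each pyramid level and concatenates (unverified against the actual implementation):

```python
import torch
import torch.nn.functional as F

def feature_vector_fn(feats: list[torch.Tensor]) -> torch.Tensor:
    """Global-average-pool each feature level and concatenate (assumed behaviour)."""
    pooled = [F.adaptive_avg_pool2d(f, output_size=1).flatten(1) for f in feats]
    return torch.cat(pooled, dim=1)

feats = [torch.randn(2, 256, s, s) for s in (64, 32, 16)]
print(feature_vector_fn(feats).shape)  # torch.Size([2, 768])
```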
