Skip to content

Commit

Permalink
Sanitize export outputs for vits
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Aug 2, 2024
1 parent d977505 commit 5a50259
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/otx/algo/classification/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
x = self.model.backbone.norm(x)
if self.model.neck is not None:
x = self.model.neck(x)

# Head
cls_token = x[:, 0]
layer_output = [None, cls_token]
Expand Down Expand Up @@ -140,14 +139,18 @@ def _forward_explain_image_classifier(
scores = pred_results.unbind(0)
labels = logits.argmax(-1, keepdim=True).unbind(0)

return {
outputs = {
"logits": logits,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
"scores": scores,
"labels": labels,
}

if not torch._C._is_tracing():
outputs["scores"] = scores
outputs["labels"] = labels

return outputs

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
explainer = ViTReciproCAM(
Expand Down

0 comments on commit 5a50259

Please sign in to comment.