From 65ddbfacaf401ee447a9e3e4984ac8c483b0407d Mon Sep 17 00:00:00 2001 From: Sungman Cho <sungman.cho@intel.com> Date: Tue, 24 Oct 2023 13:11:40 +0900 Subject: [PATCH] Fix the CustomNonLinearClsHead when the batch_size is set to 1 (#2571) Fix bn1d issue Co-authored-by: sungmanc <sungmanc@intel.com> --- .../adapters/mmcls/models/heads/custom_cls_head.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py index fcf2008e795..ec760df7f50 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py @@ -42,6 +42,10 @@ def forward(self, x): def forward_train(self, cls_score, gt_label): """Forward_train fuction of CustomNonLinearHead class.""" + bs = cls_score.shape[0] + if bs == 1: + cls_score = torch.cat([cls_score, cls_score], dim=0) + gt_label = torch.cat([gt_label, gt_label], dim=0) logit = self.classifier(cls_score) losses = self.loss(logit, gt_label, feature=cls_score) return losses