From 65ddbfacaf401ee447a9e3e4984ac8c483b0407d Mon Sep 17 00:00:00 2001
From: Sungman Cho <sungman.cho@intel.com>
Date: Tue, 24 Oct 2023 13:11:40 +0900
Subject: [PATCH] Fix the CustomNonLinearClsHead when the batch_size is set to
 1 (#2571)

Fix the BatchNorm1d (bn1d) failure in the non-linear classifier head when the batch size is 1

Co-authored-by: sungmanc <sungmanc@intel.com>
---
 .../adapters/mmcls/models/heads/custom_cls_head.py            | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py
index fcf2008e795..ec760df7f50 100644
--- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py
+++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py
@@ -42,6 +42,10 @@ def forward(self, x):
 
     def forward_train(self, cls_score, gt_label):
         """Forward_train fuction of CustomNonLinearHead class."""
+        bs = cls_score.shape[0]
+        if bs == 1:
+            cls_score = torch.cat([cls_score, cls_score], dim=0)
+            gt_label = torch.cat([gt_label, gt_label], dim=0)
         logit = self.classifier(cls_score)
         losses = self.loss(logit, gt_label, feature=cls_score)
         return losses
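
The underlying failure comes from torch.nn.BatchNorm1d, which in training mode requires more than one value per channel and raises a ValueError for a single-sample batch. The sketch below is a minimal illustration, not the actual CustomNonLinearClsHead: an assumed nn.Sequential classifier containing BatchNorm1d stands in for self.classifier, and it shows both the failure and the duplication workaround applied by the patch.

```python
import torch
from torch import nn

# Hypothetical stand-in for the head's non-linear classifier; the real
# CustomNonLinearClsHead differs, but is assumed to contain BatchNorm1d.
classifier = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16),  # fails in train mode when the batch holds a single sample
    nn.ReLU(),
    nn.Linear(16, 3),
)
classifier.train()

cls_score = torch.randn(1, 8)   # batch_size == 1
gt_label = torch.tensor([2])

try:
    classifier(cls_score)
except ValueError as err:
    # "Expected more than 1 value per channel when training, ..."
    print(err)

# Workaround used by the patch: duplicate the single sample and its label
# so BatchNorm1d sees a batch of two identical rows.
if cls_score.shape[0] == 1:
    cls_score = torch.cat([cls_score, cls_score], dim=0)
    gt_label = torch.cat([gt_label, gt_label], dim=0)

logits = classifier(cls_score)       # now succeeds
print(logits.shape, gt_label.shape)  # torch.Size([2, 3]) torch.Size([2])
```

With two identical rows the batch statistics are degenerate (zero within-batch variance), but the forward pass completes and a loss can be computed, which is all the fix needs. The added lines also rely on torch being importable in custom_cls_head.py, which the diff assumes is already the case.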