From 49c503daee5ec328e292d226e8cd6a172021ba00 Mon Sep 17 00:00:00 2001 From: Soobee Lee Date: Thu, 29 Dec 2022 16:20:30 +0900 Subject: [PATCH] Fix semisl trainer import (#1471) Co-authored-by: Lee, Soobee --- otx/mpa/det/incremental/stage.py | 7 ++++--- otx/mpa/det/semisl/__init__.py | 5 +++++ otx/mpa/det/semisl/stage.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/otx/mpa/det/incremental/stage.py b/otx/mpa/det/incremental/stage.py index 2c760a79e3a..e591355c904 100644 --- a/otx/mpa/det/incremental/stage.py +++ b/otx/mpa/det/incremental/stage.py @@ -7,6 +7,7 @@ from mmcv import ConfigDict from mmdet.datasets import build_dataset +from otx.mpa.stage import Stage from otx.mpa.det.stage import DetectionStage from otx.mpa.utils.config_utils import update_or_add_custom_hook from otx.mpa.utils.logger import get_logger @@ -115,7 +116,7 @@ def configure_anchor(self, cfg, proposal_ratio=None): def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model_classes): """Patch config for incremental learning""" if task_adapt_type == "mpa": - self.configure_bbox_head(cfg, model_classes) + self.configure_bbox_head(cfg, org_model_classes, model_classes) self.configure_task_adapt_hook(cfg, org_model_classes, model_classes) self.configure_ema(cfg) self.configure_val_interval(cfg) @@ -123,7 +124,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model src_data_cfg = self.get_data_cfg(cfg, "train") src_data_cfg.pop("old_new_indices", None) - def configure_bbox_head(self, cfg, model_classes): + def configure_bbox_head(self, cfg, org_model_classes, model_classes): """Patch bbox head in detector for class incremental learning. Most of patching are related with hyper-params in focal loss """ @@ -225,7 +226,7 @@ def get_img_ids_for_incr(cfg, org_model_classes, model_classes): new_classes = np.setdiff1d(model_classes, org_model_classes).tolist() old_classes = np.intersect1d(org_model_classes, model_classes).tolist() - src_data_cfg = self.get_data_cfg(cfg, "train") + src_data_cfg = Stage.get_data_cfg(cfg, "train") ids_old, ids_new = [], [] data_cfg = cfg.data.test.copy() diff --git a/otx/mpa/det/semisl/__init__.py b/otx/mpa/det/semisl/__init__.py index 1e19f1159d9..5b83a3f3c05 100644 --- a/otx/mpa/det/semisl/__init__.py +++ b/otx/mpa/det/semisl/__init__.py @@ -1,3 +1,8 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # + +from .inferrer import SemiSLDetectionInferrer +from .trainer import SemiSLDetectionTrainer + +__all__ = ["SemiSLDetectionInferrer", "SemiSLDetectionTrainer"] diff --git a/otx/mpa/det/semisl/stage.py b/otx/mpa/det/semisl/stage.py index 0330b1f88a3..ac921aa5ea0 100644 --- a/otx/mpa/det/semisl/stage.py +++ b/otx/mpa/det/semisl/stage.py @@ -45,7 +45,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model Semi supervised learning should support incrmental learning """ if task_adapt_type == "mpa": - self.configure_bbox_head(cfg, model_classes) + self.configure_bbox_head(cfg, org_model_classes, model_classes) self.configure_task_adapt_hook(cfg, org_model_classes, model_classes) self.configure_val_interval(cfg) else: