Skip to content

Commit

Permalink
Fix semisl trainer import (#1471)
Browse files Browse the repository at this point in the history
Co-authored-by: Lee, Soobee <soobeele@intel.com>
  • Loading branch information
2 people authored and sungmanc committed Dec 29, 2022
1 parent 6e86819 commit 49c503d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
7 changes: 4 additions & 3 deletions otx/mpa/det/incremental/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmcv import ConfigDict
from mmdet.datasets import build_dataset

from otx.mpa.stage import Stage
from otx.mpa.det.stage import DetectionStage
from otx.mpa.utils.config_utils import update_or_add_custom_hook
from otx.mpa.utils.logger import get_logger
Expand Down Expand Up @@ -115,15 +116,15 @@ def configure_anchor(self, cfg, proposal_ratio=None):
def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model_classes):
"""Patch config for incremental learning"""
if task_adapt_type == "mpa":
self.configure_bbox_head(cfg, model_classes)
self.configure_bbox_head(cfg, org_model_classes, model_classes)
self.configure_task_adapt_hook(cfg, org_model_classes, model_classes)
self.configure_ema(cfg)
self.configure_val_interval(cfg)
else:
src_data_cfg = self.get_data_cfg(cfg, "train")
src_data_cfg.pop("old_new_indices", None)

def configure_bbox_head(self, cfg, model_classes):
def configure_bbox_head(self, cfg, org_model_classes, model_classes):
"""Patch bbox head in detector for class incremental learning.
Most of patching are related with hyper-params in focal loss
"""
Expand Down Expand Up @@ -225,7 +226,7 @@ def get_img_ids_for_incr(cfg, org_model_classes, model_classes):
new_classes = np.setdiff1d(model_classes, org_model_classes).tolist()
old_classes = np.intersect1d(org_model_classes, model_classes).tolist()

src_data_cfg = self.get_data_cfg(cfg, "train")
src_data_cfg = Stage.get_data_cfg(cfg, "train")

ids_old, ids_new = [], []
data_cfg = cfg.data.test.copy()
Expand Down
5 changes: 5 additions & 0 deletions otx/mpa/det/semisl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .inferrer import SemiSLDetectionInferrer
from .trainer import SemiSLDetectionTrainer

__all__ = ["SemiSLDetectionInferrer", "SemiSLDetectionTrainer"]
2 changes: 1 addition & 1 deletion otx/mpa/det/semisl/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model
Semi supervised learning should support incrmental learning
"""
if task_adapt_type == "mpa":
self.configure_bbox_head(cfg, model_classes)
self.configure_bbox_head(cfg, org_model_classes, model_classes)
self.configure_task_adapt_hook(cfg, org_model_classes, model_classes)
self.configure_val_interval(cfg)
else:
Expand Down

0 comments on commit 49c503d

Please sign in to comment.