From f72829518a5c397e425d71cfa8334cccca549035 Mon Sep 17 00:00:00 2001 From: Eunwoo Shin Date: Fri, 6 Jan 2023 13:27:37 +0900 Subject: [PATCH] [OTX] Evaluate a model before training starts (#1472) * add eval_before_train_hook * fix for pre-comimt test --- otx/algorithms/common/tasks/training_base.py | 3 + otx/mpa/modules/hooks/__init__.py | 1 + .../modules/hooks/eval_before_train_hook.py | 56 +++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 otx/mpa/modules/hooks/eval_before_train_hook.py diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py index cdacbdad697..155856c005d 100644 --- a/otx/algorithms/common/tasks/training_base.py +++ b/otx/algorithms/common/tasks/training_base.py @@ -242,6 +242,9 @@ def _initialize(self, export=False): # to disenable early stopping during self-sl self.set_early_stopping_hook() + # add eval before train hook + update_or_add_custom_hook(self._recipe_cfg, ConfigDict(type="EvalBeforeTrainHook", priority="ABOVE_NORMAL")) + # add Cancel tranining hook update_or_add_custom_hook( self._recipe_cfg, diff --git a/otx/mpa/modules/hooks/__init__.py b/otx/mpa/modules/hooks/__init__.py index 406b5d8fa50..71f32717600 100644 --- a/otx/mpa/modules/hooks/__init__.py +++ b/otx/mpa/modules/hooks/__init__.py @@ -7,6 +7,7 @@ adaptive_training_hooks, checkpoint_hook, early_stopping_hook, + eval_before_train_hook, fp16_sam_optimizer_hook, gpu_monitor, hpo_hook, diff --git a/otx/mpa/modules/hooks/eval_before_train_hook.py b/otx/mpa/modules/hooks/eval_before_train_hook.py new file mode 100644 index 00000000000..3d5a7242156 --- /dev/null +++ b/otx/mpa/modules/hooks/eval_before_train_hook.py @@ -0,0 +1,56 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from mmcv.runner import HOOKS +from mmcv.runner import EvalHook as SegEvalHook +from mmcv.runner import Hook + +from otx.mpa.modules.hooks.checkpoint_hook import CheckpointHookWithValResults +from otx.mpa.modules.hooks.eval_hook import CustomEvalHook as ClsEvalHook + +try: + from mmdet.core.evaluation.eval_hooks import EvalHook as DetEvalHook +except ImportError: + DetEvalHook = None + + +@HOOKS.register_module() +class EvalBeforeTrainHook(Hook): + """Hook to evaluate and save model weight before training.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._executed = False + + def before_train_epoch(self, runner): + """Execute the evaluation hook before training""" + if not self._executed: + for hook in runner.hooks: + if self.check_eval_hook(hook): + self.execute_hook(hook, runner) + if not issubclass(type(hook), ClsEvalHook): + break + + # cls task saves the model weight in a checkpoint hook after eval hook is executed + if issubclass(type(hook), CheckpointHookWithValResults): + self.execute_hook(hook, runner) + break + + self._executed = True + + @staticmethod + def check_eval_hook(hook: Hook): + """Check that the hook is an evaluation hook.""" + eval_hook_types = (ClsEvalHook, SegEvalHook) + if DetEvalHook is not None: + eval_hook_types += (DetEvalHook,) + return issubclass(type(hook), eval_hook_types) + + @staticmethod + def execute_hook(hook: Hook, runner): + """Execute after_train_epoch or iter depending on `by_epoch` value""" + if getattr(hook, "by_epoch", True): + hook.after_train_epoch(runner) + else: + hook.after_train_iter(runner)