Skip to content

Commit

Permalink
[OTX] Evaluate a model before training starts (openvinotoolkit#1472)
Browse files Browse the repository at this point in the history
* add eval_before_train_hook

* fix for pre-comimt test
  • Loading branch information
eunwoosh authored Jan 6, 2023
1 parent 2c9c4f3 commit f728295
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
3 changes: 3 additions & 0 deletions otx/algorithms/common/tasks/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ def _initialize(self, export=False):
# to disenable early stopping during self-sl
self.set_early_stopping_hook()

# add eval before train hook
update_or_add_custom_hook(self._recipe_cfg, ConfigDict(type="EvalBeforeTrainHook", priority="ABOVE_NORMAL"))

# add Cancel tranining hook
update_or_add_custom_hook(
self._recipe_cfg,
Expand Down
1 change: 1 addition & 0 deletions otx/mpa/modules/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
adaptive_training_hooks,
checkpoint_hook,
early_stopping_hook,
eval_before_train_hook,
fp16_sam_optimizer_hook,
gpu_monitor,
hpo_hook,
Expand Down
56 changes: 56 additions & 0 deletions otx/mpa/modules/hooks/eval_before_train_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmcv.runner import HOOKS
from mmcv.runner import EvalHook as SegEvalHook
from mmcv.runner import Hook

from otx.mpa.modules.hooks.checkpoint_hook import CheckpointHookWithValResults
from otx.mpa.modules.hooks.eval_hook import CustomEvalHook as ClsEvalHook

try:
from mmdet.core.evaluation.eval_hooks import EvalHook as DetEvalHook
except ImportError:
DetEvalHook = None


@HOOKS.register_module()
class EvalBeforeTrainHook(Hook):
"""Hook to evaluate and save model weight before training."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._executed = False

def before_train_epoch(self, runner):
"""Execute the evaluation hook before training"""
if not self._executed:
for hook in runner.hooks:
if self.check_eval_hook(hook):
self.execute_hook(hook, runner)
if not issubclass(type(hook), ClsEvalHook):
break

# cls task saves the model weight in a checkpoint hook after eval hook is executed
if issubclass(type(hook), CheckpointHookWithValResults):
self.execute_hook(hook, runner)
break

self._executed = True

@staticmethod
def check_eval_hook(hook: Hook):
"""Check that the hook is an evaluation hook."""
eval_hook_types = (ClsEvalHook, SegEvalHook)
if DetEvalHook is not None:
eval_hook_types += (DetEvalHook,)
return issubclass(type(hook), eval_hook_types)

@staticmethod
def execute_hook(hook: Hook, runner):
"""Execute after_train_epoch or iter depending on `by_epoch` value"""
if getattr(hook, "by_epoch", True):
hook.after_train_epoch(runner)
else:
hook.after_train_iter(runner)

0 comments on commit f728295

Please sign in to comment.