From f72829518a5c397e425d71cfa8334cccca549035 Mon Sep 17 00:00:00 2001
From: Eunwoo Shin <eunwoo.shin@intel.com>
Date: Fri, 6 Jan 2023 13:27:37 +0900
Subject: [PATCH] [OTX] Evaluate a model before training starts (#1472)

* add eval_before_train_hook

* fix for pre-comimt test
---
 otx/algorithms/common/tasks/training_base.py  |  3 +
 otx/mpa/modules/hooks/__init__.py             |  1 +
 .../modules/hooks/eval_before_train_hook.py   | 56 +++++++++++++++++++
 3 files changed, 60 insertions(+)
 create mode 100644 otx/mpa/modules/hooks/eval_before_train_hook.py

diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py
index cdacbdad697..155856c005d 100644
--- a/otx/algorithms/common/tasks/training_base.py
+++ b/otx/algorithms/common/tasks/training_base.py
@@ -242,6 +242,9 @@ def _initialize(self, export=False):
             # to disenable early stopping during self-sl
             self.set_early_stopping_hook()
 
+        # add eval before train hook
+        update_or_add_custom_hook(self._recipe_cfg, ConfigDict(type="EvalBeforeTrainHook", priority="ABOVE_NORMAL"))
+
         # add Cancel tranining hook
         update_or_add_custom_hook(
             self._recipe_cfg,
diff --git a/otx/mpa/modules/hooks/__init__.py b/otx/mpa/modules/hooks/__init__.py
index 406b5d8fa50..71f32717600 100644
--- a/otx/mpa/modules/hooks/__init__.py
+++ b/otx/mpa/modules/hooks/__init__.py
@@ -7,6 +7,7 @@
     adaptive_training_hooks,
     checkpoint_hook,
     early_stopping_hook,
+    eval_before_train_hook,
     fp16_sam_optimizer_hook,
     gpu_monitor,
     hpo_hook,
diff --git a/otx/mpa/modules/hooks/eval_before_train_hook.py b/otx/mpa/modules/hooks/eval_before_train_hook.py
new file mode 100644
index 00000000000..3d5a7242156
--- /dev/null
+++ b/otx/mpa/modules/hooks/eval_before_train_hook.py
@@ -0,0 +1,56 @@
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from mmcv.runner import HOOKS
+from mmcv.runner import EvalHook as SegEvalHook
+from mmcv.runner import Hook
+
+from otx.mpa.modules.hooks.checkpoint_hook import CheckpointHookWithValResults
+from otx.mpa.modules.hooks.eval_hook import CustomEvalHook as ClsEvalHook
+
+try:
+    from mmdet.core.evaluation.eval_hooks import EvalHook as DetEvalHook
+except ImportError:
+    DetEvalHook = None
+
+
+@HOOKS.register_module()
+class EvalBeforeTrainHook(Hook):
+    """Hook to evaluate and save model weight before training."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._executed = False
+
+    def before_train_epoch(self, runner):
+        """Execute the evaluation hook before training"""
+        if not self._executed:
+            for hook in runner.hooks:
+                if self.check_eval_hook(hook):
+                    self.execute_hook(hook, runner)
+                    if not issubclass(type(hook), ClsEvalHook):
+                        break
+
+                # cls task saves the model weight in a checkpoint hook after eval hook is executed
+                if issubclass(type(hook), CheckpointHookWithValResults):
+                    self.execute_hook(hook, runner)
+                    break
+
+            self._executed = True
+
+    @staticmethod
+    def check_eval_hook(hook: Hook):
+        """Check that the hook is an evaluation hook."""
+        eval_hook_types = (ClsEvalHook, SegEvalHook)
+        if DetEvalHook is not None:
+            eval_hook_types += (DetEvalHook,)
+        return issubclass(type(hook), eval_hook_types)
+
+    @staticmethod
+    def execute_hook(hook: Hook, runner):
+        """Execute after_train_epoch or iter depending on `by_epoch` value"""
+        if getattr(hook, "by_epoch", True):
+            hook.after_train_epoch(runner)
+        else:
+            hook.after_train_iter(runner)