diff --git a/external/anomaly/ote_anomalib/callbacks/__init__.py b/external/anomaly/ote_anomalib/callbacks/__init__.py index 4d4fc9b2e79..c6dbaa7c72d 100644 --- a/external/anomaly/ote_anomalib/callbacks/__init__.py +++ b/external/anomaly/ote_anomalib/callbacks/__init__.py @@ -18,5 +18,6 @@ from .inference import AnomalyInferenceCallback from .progress import ProgressCallback +from .score_report import ScoreReportingCallback -__all__ = ["AnomalyInferenceCallback", "ProgressCallback"] +__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "ScoreReportingCallback"] diff --git a/external/anomaly/ote_anomalib/callbacks/score_report.py b/external/anomaly/ote_anomalib/callbacks/score_report.py new file mode 100644 index 00000000000..bfdf6671f37 --- /dev/null +++ b/external/anomaly/ote_anomalib/callbacks/score_report.py @@ -0,0 +1,43 @@ +"""Score reporting callback""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from typing import Optional + +from ote_sdk.entities.train_parameters import TrainParameters +from pytorch_lightning import Callback + + +class ScoreReportingCallback(Callback): + """ + Callback for reporting score. + """ + + def __init__(self, parameters: Optional[TrainParameters] = None) -> None: + if parameters is not None: + self.score_reporting_callback = parameters.update_progress + else: + self.score_reporting_callback = None + + def on_validation_epoch_end(self, trainer, pl_module): + """ + If score exists in trainer.logged_metrics, report the score. + """ + if self.score_reporting_callback is not None: + score = None + metric = getattr(self.score_reporting_callback, 'metric', None) + if metric in trainer.logged_metrics: + score = float(trainer.logged_metrics[metric]) + self.score_reporting_callback(progress=0, score=score) diff --git a/external/anomaly/ote_anomalib/train_task.py b/external/anomaly/ote_anomalib/train_task.py index d2de58ab9dd..28d99d69404 100644 --- a/external/anomaly/ote_anomalib/train_task.py +++ b/external/anomaly/ote_anomalib/train_task.py @@ -16,7 +16,7 @@ from anomalib.utils.callbacks import MinMaxNormalizationCallback from ote_anomalib import AnomalyInferenceTask -from ote_anomalib.callbacks import ProgressCallback +from ote_anomalib.callbacks import ProgressCallback, ScoreReportingCallback from ote_anomalib.data import OTEAnomalyDataModule from ote_anomalib.logging import get_logger from ote_sdk.entities.datasets import DatasetEntity @@ -50,7 +50,11 @@ def train( logger.info("Training Configs '%s'", config) datamodule = OTEAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type) - callbacks = [ProgressCallback(parameters=train_parameters), MinMaxNormalizationCallback()] + callbacks = [ + ProgressCallback(parameters=train_parameters), + MinMaxNormalizationCallback(), + ScoreReportingCallback(parameters=train_parameters) + ] self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks) self.trainer.fit(model=self.model, datamodule=datamodule)