Skip to content

Commit

Permalink
make new callback function
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed May 31, 2022
1 parent 4ea9d44 commit 4fe2598
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
3 changes: 2 additions & 1 deletion external/anomaly/ote_anomalib/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@

from .inference import AnomalyInferenceCallback
from .progress import ProgressCallback
from .score_report import ScoreReportingCallback

__all__ = ["AnomalyInferenceCallback", "ProgressCallback"]
__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "ScoreReportingCallback"]
9 changes: 0 additions & 9 deletions external/anomaly/ote_anomalib/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,6 @@ def _get_progress(self, stage: str = "train") -> float:
raise ValueError(f"Unknown stage {stage}. Available: train, predict and test")

return self._progress

def on_train_epoch_end(self, trainer, pl_module):
super().on_train_epoch_end(trainer, pl_module)
score = None
metric = getattr(self.update_progress_callback, 'metric', None)
if metric in trainer.logged_metrics:
score = float(trainer.logged_metrics[metric])
progress = int(self._get_progress('train'))
self.update_progress_callback(progress=progress, score=score)

def _update_progress(self, stage: str):
progress = self._get_progress(stage)
Expand Down
43 changes: 43 additions & 0 deletions external/anomaly/ote_anomalib/callbacks/score_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Score reporting callback"""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Optional

from ote_sdk.entities.train_parameters import TrainParameters
from pytorch_lightning import Callback


class ScoreReportingCallback(Callback):
    """Lightning callback that reports the logged validation score to OTE."""

    def __init__(self, parameters: Optional[TrainParameters] = None) -> None:
        # Keep the task's update_progress callable when training parameters
        # are provided; otherwise reporting is disabled (callback stays None).
        self.score_reporting_callback = None if parameters is None else parameters.update_progress

    def on_train_epoch_end(self, trainer, pl_module):
        """
        If score exists in trainer.logged_metrics, report the score.
        """
        callback = self.score_reporting_callback
        if callback is None:
            return
        # The reporting callable may carry a 'metric' attribute naming the
        # metric to forward; absent that, no score is looked up.
        metric_name = getattr(callback, 'metric', None)
        score = None
        if metric_name in trainer.logged_metrics:
            score = float(trainer.logged_metrics[metric_name])
        # progress=0 — progress itself is reported elsewhere (ProgressCallback).
        callback(progress=0, score=score)
8 changes: 6 additions & 2 deletions external/anomaly/ote_anomalib/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from anomalib.utils.callbacks import MinMaxNormalizationCallback
from ote_anomalib import AnomalyInferenceTask
from ote_anomalib.callbacks import ProgressCallback
from ote_anomalib.callbacks import ProgressCallback, ScoreReportingCallback
from ote_anomalib.data import OTEAnomalyDataModule
from ote_anomalib.logging import get_logger
from ote_sdk.entities.datasets import DatasetEntity
Expand Down Expand Up @@ -50,7 +50,11 @@ def train(
logger.info("Training Configs '%s'", config)

datamodule = OTEAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
callbacks = [ProgressCallback(parameters=train_parameters), MinMaxNormalizationCallback()]
callbacks = [
ProgressCallback(parameters=train_parameters),
MinMaxNormalizationCallback(),
ScoreReportingCallback(parameters=train_parameters)
]

self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
self.trainer.fit(model=self.model, datamodule=datamodule)
Expand Down

0 comments on commit 4fe2598

Please sign in to comment.