Skip to content

Commit

Permalink
Fix draem
Browse files Browse the repository at this point in the history
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim committed Mar 19, 2024
1 parent 6558e7b commit a1f5967
Showing 1 changed file with 78 additions and 2 deletions.
80 changes: 78 additions & 2 deletions src/otx/algo/anomaly/draem.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
"""OTX Draem model."""
# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING

from anomalib.models.image import Draem as AnomalibDraem

from otx.core.model.entity.base import OTXModel
from otx.core.model.module.anomaly import OTXAnomaly
from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.base import OTXModel

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs


class Draem(OTXAnomaly, OTXModel, AnomalibDraem):
Expand Down Expand Up @@ -40,3 +50,69 @@ def __init__(
anomaly_source_path=anomaly_source_path,
beta=beta,
)

def configure_metric(self) -> None:
"""This does not follow OTX metric configuration."""
return

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""STFPM does not follow OTX optimizer configuration."""
return Draem.configure_optimizers(self)

def on_validation_epoch_start(self) -> None:
"""Callback triggered when the validation epoch starts."""
Draem.on_validation_epoch_start(self)

def on_test_epoch_start(self) -> None:
"""Callback triggered when the test epoch starts."""
Draem.on_test_epoch_start(self)

def on_validation_epoch_end(self) -> None:
"""Callback triggered when the validation epoch ends."""
Draem.on_validation_epoch_end(self)

def on_test_epoch_end(self) -> None:
"""Callback triggered when the test epoch ends."""
Draem.on_test_epoch_end(self)

def training_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call training step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return Draem.training_step(self, inputs, batch_idx) # type: ignore[misc]

def validation_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call validation step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return Draem.validation_step(self, inputs, batch_idx) # type: ignore[misc]

def test_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return Draem.test_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def predict_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return Draem.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

0 comments on commit a1f5967

Please sign in to comment.