Skip to content

Commit

Permalink
Merge branch 'develop' into kp/add_semisl
Browse files Browse the repository at this point in the history
  • Loading branch information
kprokofi committed Aug 9, 2024
2 parents 8542a8d + 4a5e66c commit f424970
Show file tree
Hide file tree
Showing 23 changed files with 707 additions and 1,763 deletions.
22 changes: 17 additions & 5 deletions src/otx/algo/anomaly/padim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@

from typing import TYPE_CHECKING, Literal

from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
from anomalib.callbacks.post_processor import _PostProcessorCallback
from anomalib.models.image import Padim as AnomalibPadim

from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.base import OTXModel
from otx.core.types.label import AnomalyLabelInfo
from otx.core.types.task import OTXTaskType

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs
from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs


class Padim(OTXAnomaly, OTXModel, AnomalibPadim):
class Padim(OTXAnomaly, AnomalibPadim):
"""OTX Padim model.
Args:
Expand All @@ -49,7 +49,6 @@ def __init__(
] = OTXTaskType.ANOMALY_CLASSIFICATION,
) -> None:
OTXAnomaly.__init__(self)
OTXModel.__init__(self, label_info=AnomalyLabelInfo())
AnomalibPadim.__init__(
self,
backbone=backbone,
Expand Down Expand Up @@ -132,3 +131,16 @@ def predict_step(
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def forward(
self,
inputs: AnomalyModelInputs,
) -> AnomalyModelOutputs:
"""Wrap forward method of the Anomalib model."""
outputs = self.validation_step(inputs)
# TODO(Ashwin): update forward implementation to comply with other OTX models
_PostProcessorCallback._post_process(outputs) # noqa: SLF001
_PostProcessorCallback._compute_scores_and_labels(self, outputs) # noqa: SLF001
_MinMaxNormalizationCallback._normalize_batch(outputs, self) # noqa: SLF001

return self._customize_outputs(outputs=outputs, inputs=inputs)
22 changes: 17 additions & 5 deletions src/otx/algo/anomaly/stfpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@

from typing import TYPE_CHECKING, Literal, Sequence

from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
from anomalib.callbacks.post_processor import _PostProcessorCallback
from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm

from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.base import OTXModel
from otx.core.types.label import AnomalyLabelInfo
from otx.core.types.task import OTXTaskType

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs
from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs


class Stfpm(OTXAnomaly, OTXModel, AnomalibStfpm):
class Stfpm(OTXAnomaly, AnomalibStfpm):
"""OTX STFPM model.
Args:
Expand All @@ -46,7 +46,6 @@ def __init__(
**kwargs,
) -> None:
OTXAnomaly.__init__(self)
OTXModel.__init__(self, label_info=AnomalyLabelInfo())
AnomalibStfpm.__init__(
self,
backbone=backbone,
Expand Down Expand Up @@ -124,3 +123,16 @@ def predict_step(
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibStfpm.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def forward(
self,
inputs: AnomalyModelInputs,
) -> AnomalyModelOutputs:
"""Wrap forward method of the Anomalib model."""
outputs = self.validation_step(inputs)
# TODO(Ashwin): update forward implementation to comply with other OTX models
_PostProcessorCallback._post_process(outputs) # noqa: SLF001
_PostProcessorCallback._compute_scores_and_labels(self, outputs) # noqa: SLF001
_MinMaxNormalizationCallback._normalize_batch(outputs, self) # noqa: SLF001

return self._customize_outputs(outputs=outputs, inputs=inputs)
4 changes: 2 additions & 2 deletions src/otx/algo/classification/backbones/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def init_weights(self, pretrained: bool | str | None = None) -> None:
checkpoint = torch.load(pretrained, None)
load_checkpoint_to_model(self, checkpoint)
print(f"init weight - {pretrained}")
elif pretrained is not None:
elif pretrained:
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
download_model(net=self, model_name=self.model_name, local_model_store_dir_path=str(cache_dir))
print(f"init weight - {pretrained_urls[self.model_name]}")
print(f"Download model weight in {cache_dir!s}")
Loading

0 comments on commit f424970

Please sign in to comment.