Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OTX/Anomaly] Add changes from external to otx #1452

Merged
merged 5 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,5 @@

from .inference import AnomalyInferenceCallback
from .progress import ProgressCallback
from .score_report import ScoreReportingCallback

__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "ScoreReportingCallback"]
__all__ = ["AnomalyInferenceCallback", "ProgressCallback"]
27 changes: 23 additions & 4 deletions otx/algorithms/anomaly/adapters/anomalib/callbacks/progress.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Progressbar Callback for OTX task."""
"""Progressbar and Score Reporting callback Callback for OTX task.

TODO Since only one progressbar callback is supported HPO is combined into one callback. Remove this after the refactor
"""

# Copyright (C) 2021 Intel Corporation
#
Expand Down Expand Up @@ -38,9 +41,9 @@ def __init__(
self._progress: float = 0

if parameters is not None:
self.update_progress_callback = parameters.update_progress
self.progress_and_hpo_callback = parameters.update_progress
else:
self.update_progress_callback = default_progress_callback
self.progress_and_hpo_callback = default_progress_callback

def on_train_start(self, trainer, pl_module):
"""Store max epochs and current epoch from trainer."""
Expand Down Expand Up @@ -75,6 +78,22 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update_progress(stage="test")

def on_validation_epoch_end(self, trainer, pl_module): # pylint: disable=unused-argument
"""If score exists in trainer.logged_metrics, report the score."""
if self.progress_and_hpo_callback is not None:
score = None
metric = getattr(self.progress_and_hpo_callback, "metric", None)
print(f"[DEBUG-HPO] logged_metrics = {trainer.logged_metrics}")
if metric in trainer.logged_metrics:
score = float(trainer.logged_metrics[metric])
if score < 1.0:
score = score + int(trainer.global_step)
else:
score = -(score + int(trainer.global_step))

# Always assumes that hpo validation step is called during training.
self.progress_and_hpo_callback(int(self._get_progress("train")), score) # pylint: disable=not-callable

def _reset_progress(self):
self._progress = 0.0

Expand Down Expand Up @@ -104,4 +123,4 @@ def _get_progress(self, stage: str = "train") -> float:

def _update_progress(self, stage: str):
progress = self._get_progress(stage)
self.update_progress_callback(int(progress), None)
self.progress_and_hpo_callback(int(progress), None)
42 changes: 0 additions & 42 deletions otx/algorithms/anomaly/adapters/anomalib/callbacks/score_report.py

This file was deleted.

11 changes: 11 additions & 0 deletions otx/algorithms/anomaly/configs/base/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from attr import attrs

from otx.algorithms.anomaly.configs.base.configuration_enums import (
ModelBackbone,
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
POTQuantizationPreset,
)
from otx.api.configuration import ConfigurableParameters
Expand Down Expand Up @@ -48,6 +49,16 @@ class LearningParameters(ParameterGroup):
header = string_attribute("Learning Parameters")
description = header

# Editable is set to false as WideResNet50 is very large for
# onnx's protobuf (2gb) limit. This ends up crashing the export.
backbone = selectable(
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
default_value=ModelBackbone.RESNET18,
header="Model Backbone",
description="Pre-trained backbone used for feature extraction",
editable=False,
visible_in_ui=False,
)

train_batch_size = configurable_integer(
default_value=32,
min_value=1,
Expand Down
30 changes: 16 additions & 14 deletions otx/algorithms/anomaly/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def __init__(self, task_environment: TaskEnvironment) -> None:

def get_config(self) -> Union[DictConfig, ListConfig]:
"""Get Anomalib Config from task environment.

Returns:
Union[DictConfig, ListConfig]: Anomalib config.
"""
Expand All @@ -115,21 +114,18 @@ def get_config(self) -> Union[DictConfig, ListConfig]:

def load_model(self, otx_model: Optional[ModelEntity]) -> AnomalyModule:
"""Create and Load Anomalib Module from OTX Model.

This method checks if the task environment has a saved OTX Model,
and creates one. If the OTX model already exists, it returns the
the model with the saved weights.

Args:
otx_model (Optional[ModelEntity]): OTX Model from the
task environment.

Returns:
AnomalyModule: Anomalib
classification or segmentation model with/without weights.
"""
model = get_model(config=self.config)
if otx_model is None:
model = get_model(config=self.config)
logger.info(
"No trained model in project yet. Created new model with '%s'",
self.model_name,
Expand All @@ -138,18 +134,23 @@ def load_model(self, otx_model: Optional[ModelEntity]) -> AnomalyModule:
buffer = io.BytesIO(otx_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]:
logger.warning(
"Backbone of the model in the Task Environment is different from the one in the template. "
f"creating model with backbone={model_data['config']['model']['backbone']}"
)
self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"]
try:
model = get_model(config=self.config)
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")

except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

return model

def cancel_training(self) -> None:
"""Cancel the training `after_batch_end`.

This terminates the training; however validation is still performed.
"""
logger.info("Cancel training requested.")
Expand Down Expand Up @@ -195,7 +196,6 @@ def infer(self, dataset: DatasetEntity, inference_parameters: InferenceParameter

def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None) -> None:
"""Evaluate the performance on a result set.

Args:
output_resultset (ResultSetEntity): Result Set from which the performance is evaluated.
evaluation_metric (Optional[str], optional): Evaluation metric. Defaults to None. Instead,
Expand Down Expand Up @@ -251,8 +251,8 @@ def export(self, export_type: ExportType, output_model: ModelEntity) -> None:
logger.info("Exporting the OpenVINO model.")
onnx_path = os.path.join(self.config.project.path, "onnx_model.onnx")
self._export_to_onnx(onnx_path)
optimize_command = "mo --input_model " + onnx_path + " --output_dir " + self.config.project.path
subprocess.call(optimize_command, shell=True)
optimize_command = ["mo", "--input_model", onnx_path, "--output_dir", self.config.project.path]
subprocess.run(optimize_command, check=True)
bin_file = glob(os.path.join(self.config.project.path, "*.bin"))[0]
xml_file = glob(os.path.join(self.config.project.path, "*.xml"))[0]
with open(bin_file, "rb") as file:
Expand All @@ -266,7 +266,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity) -> None:
output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema))
self._set_metadata(output_model)

def _model_info(self) -> Dict:
def model_info(self) -> Dict:
"""Return model info to save the model weights.

Returns:
Expand All @@ -285,7 +285,7 @@ def save_model(self, output_model: ModelEntity) -> None:
output_model (ModelEntity): Output model onto which the weights are saved.
"""
logger.info("Saving the model weights.")
model_info = self._model_info()
model_info = self.model_info()
buffer = io.BytesIO()
torch.save(model_info, buffer)
output_model.set_data("weights.pth", buffer.getvalue())
Expand All @@ -301,8 +301,10 @@ def save_model(self, output_model: ModelEntity) -> None:
output_model.optimization_methods = self.optimization_methods

def _set_metadata(self, output_model: ModelEntity):
output_model.set_data("image_threshold", self.model.image_threshold.value.cpu().numpy().tobytes())
output_model.set_data("pixel_threshold", self.model.pixel_threshold.value.cpu().numpy().tobytes())
if hasattr(self.model, "image_threshold"):
output_model.set_data("image_threshold", self.model.image_threshold.value.cpu().numpy().tobytes())
if hasattr(self.model, "pixel_threshold"):
output_model.set_data("pixel_threshold", self.model.pixel_threshold.value.cpu().numpy().tobytes())
if hasattr(self.model, "normalization_metrics"):
output_model.set_data("min", self.model.normalization_metrics.state_dict()["min"].cpu().numpy().tobytes())
output_model.set_data("max", self.model.normalization_metrics.state_dict()["max"].cpu().numpy().tobytes())
Expand Down
2 changes: 1 addition & 1 deletion otx/algorithms/anomaly/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def optimize(

logger.info("Training completed.")

def _model_info(self) -> Dict:
def model_info(self) -> Dict:
"""Return model info to save the model weights.

Returns:
Expand Down
46 changes: 40 additions & 6 deletions otx/algorithms/anomaly/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import io
from typing import Optional

import torch
from anomalib.models import AnomalyModule, get_model
from anomalib.utils.callbacks import (
MetricsConfigurationCallback,
MinMaxNormalizationCallback,
)
from pytorch_lightning import Trainer, seed_everything

from otx.algorithms.anomaly.adapters.anomalib.callbacks import (
ProgressCallback,
ScoreReportingCallback,
)
from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.anomaly.adapters.anomalib.logger import get_logger
from otx.api.entities.datasets import DatasetEntity
Expand All @@ -49,7 +49,6 @@ def train(
seed: Optional[int] = None,
) -> None:
"""Train the anomaly classification model.

Args:
dataset (DatasetEntity): Input dataset.
output_model (ModelEntity): Output model to save the model weights.
Expand All @@ -72,7 +71,6 @@ def train(
callbacks = [
ProgressCallback(parameters=train_parameters),
MinMaxNormalizationCallback(),
ScoreReportingCallback(parameters=train_parameters),
MetricsConfigurationCallback(
adaptive_threshold=config.metrics.threshold.adaptive,
default_image_threshold=config.metrics.threshold.image_default,
Expand All @@ -88,3 +86,39 @@ def train(
self.save_model(output_model)

logger.info("Training completed.")

def load_model(self, otx_model: Optional[ModelEntity]) -> AnomalyModule:
"""Create and Load Anomalib Module from OTE Model.
This method checks if the task environment has a saved OTE Model,
and creates one. If the OTE model already exists, it returns the
the model with the saved weights.
Args:
otx_model (Optional[ModelEntity]): OTE Model from the
task environment.
Returns:
AnomalyModule: Anomalib
classification or segmentation model with/without weights.
"""
model = get_model(config=self.config)
if otx_model is None:
logger.info(
"No trained model in project yet. Created new model with '%s'",
self.model_name,
)
else:
buffer = io.BytesIO(otx_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

try:
if model_data["config"]["model"]["backbone"] == self.config["model"]["backbone"]:
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")
else:
logger.info(
"Model backbone does not match. Created new model with '%s'",
self.model_name,
)
except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

return model
2 changes: 1 addition & 1 deletion otx/api/utils/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Union
from typing import Iterable, Union

import cv2
import numpy as np
Expand Down