Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 23, 2022
1 parent 114a046 commit 8d77b50
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ nncf_optimization:
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: [ ]
rules: []
type: UI_RULES
value: false
visible_in_ui: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ nncf_optimization:
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: [ ]
rules: []
type: UI_RULES
value: false
visible_in_ui: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ nncf_optimization:
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: [ ]
rules: []
type: UI_RULES
value: false
visible_in_ui: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ nncf_optimization:
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: [ ]
rules: []
type: UI_RULES
value: false
visible_in_ui: false
Expand Down
24 changes: 15 additions & 9 deletions external/anomaly/ote_anomalib/configs/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@
from sys import maxsize

from attr import attrs
from ote_anomalib.configs.configuration_enums import POTQuantizationPreset
from ote_sdk.configuration import ConfigurableParameters
from ote_sdk.configuration.elements import (ParameterGroup,
add_parameter_group,
boolean_attribute,
configurable_boolean,
configurable_integer, selectable,
string_attribute)
from ote_sdk.configuration.elements import (
ParameterGroup,
add_parameter_group,
boolean_attribute,
configurable_boolean,
configurable_integer,
selectable,
string_attribute,
)
from ote_sdk.configuration.model_lifecycle import ModelLifecycle

from ote_anomalib.configs.configuration_enums import POTQuantizationPreset


@attrs
class BaseAnomalyConfig(ConfigurableParameters):
Expand Down Expand Up @@ -98,6 +100,10 @@ class POTParameters(ParameterGroup):

@attrs
class NNCFOptimization(ParameterGroup):
"""
Parameters for NNCF optimization
"""

header = string_attribute("Optimization by NNCF")
description = header

Expand All @@ -117,7 +123,7 @@ class NNCFOptimization(ParameterGroup):
default_value=False,
header="Whether filter pruning is supported",
description="Whether filter pruning is supported",
affects_outcome_of=ModelLifecycle.TRAINING
affects_outcome_of=ModelLifecycle.TRAINING,
)

dataset = add_parameter_group(DatasetParameters)
Expand Down
77 changes: 34 additions & 43 deletions external/anomaly/ote_anomalib/nncf_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
import io
import json
import os
import subprocess # nosec
from glob import glob
from typing import Dict, Optional
from typing import Dict, Optional, Union

import torch
from anomalib.models import AnomalyModule, get_model
Expand All @@ -30,23 +28,20 @@
is_state_nncf,
wrap_nncf_model,
)
from nncf.api.compression import CompressionAlgorithmController
from ote_anomalib import AnomalyInferenceTask
from ote_anomalib.callbacks import ProgressCallback
from ote_anomalib.data import OTEAnomalyDataModule
from ote_anomalib.logging import get_logger
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.metrics import Performance, ScoreMetric
from ote_sdk.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
OptimizationMethod,
)
from ote_sdk.entities.optimization_parameters import OptimizationParameters
from ote_sdk.entities.task_environment import TaskEnvironment
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
from ote_sdk.usecases.tasks.interfaces.export_interface import ExportType
from ote_sdk.usecases.tasks.interfaces.optimization_interface import (
IOptimizationTask,
OptimizationType,
Expand All @@ -65,8 +60,7 @@ def __init__(self, task_environment: TaskEnvironment) -> None:
Args:
task_environment (TaskEnvironment): OTE Task environment.
"""
self.val_dataloader = None
self.compression_ctrl = None
self.compression_ctrl: Union[CompressionAlgorithmController, None] = None
self.nncf_preset = "nncf_quantization"
super().__init__(task_environment)
self.optimization_type = ModelOptimizationType.NNCF
Expand All @@ -76,7 +70,10 @@ def _set_attributes_by_hyperparams(self):
pruning = self.hyper_parameters.nncf_optimization.enable_pruning
if quantization and pruning:
self.nncf_preset = "nncf_quantization_pruning"
self.optimization_methods = [OptimizationMethod.QUANTIZATION, OptimizationMethod.FILTER_PRUNING]
self.optimization_methods = [
OptimizationMethod.QUANTIZATION,
OptimizationMethod.FILTER_PRUNING,
]
self.precision = [ModelPrecision.INT8]
return
if quantization and not pruning:
Expand Down Expand Up @@ -115,33 +112,33 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
self.optimization_config = compose_nncf_config(common_nncf_config, [self.nncf_preset])
self.config.merge_with(self.optimization_config)
model = get_model(config=self.config)
if ote_model is None:
if ote_model is not None:
raise ValueError("No trained model in project. NNCF require pretrained weights to compress the model")

buffer = io.BytesIO(ote_model.get_data("weights.pth")) # type: ignore
model_data = torch.load(buffer, map_location=torch.device("cpu"))

if is_state_nncf(model_data):
logger.info("Loaded model weights from Task Environment and wrapped by NNCF")

# Workaround to fix incorrect loading state for wrapped pytorch_lighting model
new_model = dict()
for key in model_data["model"].keys():
if key.startswith("model."):
new_model[key.replace("model.", "")] = model_data["model"][key]
model_data["model"] = new_model

self.compression_ctrl, model.model = wrap_nncf_model(
model.model,
self.optimization_config["nncf_config"],
init_state_dict=model_data,
)
else:
buffer = io.BytesIO(ote_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

if is_state_nncf(model_data):
logger.info("Loaded model weights from Task Environment and wrapped by NNCF")

# Workaround to fix incorrect loading state for wrapped pytorch_lighting model
new_model = dict()
for key in model_data["model"].keys():
if key.startswith("model."):
new_model[key.replace("model.", "")] = model_data["model"][key]
model_data["model"] = new_model

self.compression_ctrl, model.model = wrap_nncf_model(
model.model, self.optimization_config["nncf_config"], init_state_dict=model_data
)
else:
try:
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")
except BaseException as exception:
raise ValueError(
"Could not load the saved model. The model file structure is invalid."
) from exception
try:
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")
except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

return model

Expand All @@ -165,13 +162,7 @@ def optimize(
if optimization_type is not OptimizationType.NNCF:
raise RuntimeError("NNCF is the only supported optimization")

# config = self.get_config()
# logger.info("Training Configs '%s'", config)

datamodule = OTEAnomalyDataModule(config=self.config, dataset=dataset, task_type=self.task_type)
# Setup dataset to initialization of compressed model
# datamodule.setup(stage="fit")
# nncf_config = yaml.safe_load(OmegaConf.to_yaml(self.config['nncf_config']))

nncf_callback = NNCFCallback(nncf_config=self.optimization_config["nncf_config"])
callbacks = [
Expand All @@ -195,7 +186,7 @@ def _model_info(self) -> Dict:
"""

return {
"compression_state": self.compression_ctrl.get_compression_state(),
"compression_state": self.compression_ctrl.get_compression_state(), # type: ignore
"meta": {
"config": self.config,
"nncf_enable_compression": True,
Expand All @@ -211,4 +202,4 @@ def _export_to_onnx(self, onnx_path: str):
Args:
onnx_path (str): path to save ONNX file
"""
self.compression_ctrl.export_model(onnx_path, "onnx_11")
self.compression_ctrl.export_model(onnx_path, "onnx_11") # type: ignore
7 changes: 1 addition & 6 deletions external/anomaly/ote_anomalib/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,14 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import io

import torch
from anomalib.utils.callbacks import MinMaxNormalizationCallback
from ote_anomalib import AnomalyInferenceTask
from ote_anomalib.callbacks import ProgressCallback
from ote_anomalib.data import OTEAnomalyDataModule
from ote_anomalib.logging import get_logger
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.metrics import Performance, ScoreMetric
from ote_sdk.entities.model import ModelEntity, ModelPrecision
from ote_sdk.entities.model import ModelEntity
from ote_sdk.entities.train_parameters import TrainParameters
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
from ote_sdk.usecases.tasks.interfaces.training_interface import ITrainingTask
from pytorch_lightning import Trainer

Expand Down

0 comments on commit 8d77b50

Please sign in to comment.