Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vsaltykovx/add mmdetection input parameters validation 2 #1112

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0b8cbf1
added input parameters validation and tests for mmdet/apis/ote
saltykox Feb 15, 2022
f9d96a7
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 15, 2022
af5b029
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 17, 2022
5a34e7f
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 17, 2022
44a6c2b
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 18, 2022
1cc4410
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 18, 2022
23207fd
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 21, 2022
adc9621
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 21, 2022
d1c521c
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 22, 2022
51c0a38
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 24, 2022
ab81a36
moved load_test_dataset function to tests\parameters_validation\valid…
saltykox Feb 25, 2022
30fbec2
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 25, 2022
7071e87
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Feb 28, 2022
eb83c34
added log messages to raise_value_error_if_parameter_has_unexpected_type
saltykox Feb 28, 2022
52fd656
added log messages to raise_value_error_if_parameter_has_unexpected_type
saltykox Feb 28, 2022
3fb9680
updated logger
saltykox Feb 28, 2022
ef9ffce
removed additional log messages
saltykox Mar 1, 2022
6919ef8
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Mar 16, 2022
1d5e06c
added check_input_parameters_type decorator
saltykox Mar 18, 2022
04e5fb8
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Mar 18, 2022
e26fd37
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Mar 21, 2022
b5f6621
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Mar 21, 2022
5ba02cc
moved mmdetection params validation tests to training_extensions
saltykox Mar 21, 2022
3783f50
fixed test_load_annotation_from_ote_dataset_call_params_validation
saltykox Mar 21, 2022
8735fbb
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Mar 22, 2022
3e33ba6
added input parameters validation in mmdet/apis/ote/apis/detection/ m…
saltykox Mar 22, 2022
4f55cae
updated expected types
saltykox Mar 22, 2022
4129bc4
fix type for weight_file
saltykox Mar 22, 2022
94c496a
optimized imports
saltykox Mar 23, 2022
cc77164
Merge remote-tracking branch 'origin/vsaltykovx/add_mmdetection_input…
saltykox Mar 23, 2022
ede7332
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Mar 25, 2022
fb0fd0e
updated function in ote_sdk/ote_sdk/utils/argument_checks.py
saltykox Mar 25, 2022
cd31623
fixed expected type in JsonFilePathCheck
saltykox Mar 25, 2022
31e179c
updated check_nested_classes_parameters function
saltykox Mar 28, 2022
6a7df1c
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
saltykox Apr 5, 2022
c99aebd
added tests to cover get_data_cfg function and StopLossNanTrainingHoo…
saltykox Apr 5, 2022
3ac55ba
updated check_input_parameters_type
saltykox Apr 5, 2022
808a0ed
Merge branch 'develop' into vsaltykovx/add_mmdetection_input_paramete…
hlewando May 17, 2022
95440ea
Merge remote-tracking branch 'origin/develop' into vsaltykovx/add_mmd…
May 18, 2022
bdb7599
Corrected config input parameter type.
sstrehlk May 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
import os
import tempfile
from collections import defaultdict
from typing import List, Optional
from typing import List, Optional, Union

import torch
from mmcv import Config, ConfigDict
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.label import LabelEntity, Domain
from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
DirectoryPathCheck,
check_input_parameters_type
)

from detection_tasks.extension.datasets.data_utils import get_anchor_boxes, \
get_sizes_from_dataset_entity, format_list_to_str
Expand All @@ -43,14 +48,16 @@
logger = get_root_logger()


@check_input_parameters_type()
def is_epoch_based_runner(runner_config: ConfigDict):
return 'Epoch' in runner_config.type


@check_input_parameters_type({"work_dir": DirectoryPathCheck})
def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domain: Domain, random_seed: Optional[int] = None):
# Set runner if not defined.
if 'runner' not in config:
config.runner = {'type': 'EpochBasedRunner'}
config.runner = ConfigDict({'type': 'EpochBasedRunner'})

# Check that there is no conflict in specification of number of training epochs.
# Move global definition of epochs inside runner config.
Expand Down Expand Up @@ -112,6 +119,7 @@ def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domai
config.seed = random_seed


@check_input_parameters_type()
def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.optimizer.lr = float(hyperparams.learning_parameters.learning_rate)
config.lr_config.warmup_iters = int(hyperparams.learning_parameters.learning_rate_warmup_iters)
Expand All @@ -126,7 +134,8 @@ def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.runner.max_iters = total_iterations


def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
@check_input_parameters_type()
def patch_adaptive_repeat_dataset(config: Union[Config, ConfigDict], num_samples: int,
decay: float = -0.002, factor: float = 30):
""" Patch the repeat times and training epochs adatively

Expand Down Expand Up @@ -155,14 +164,17 @@ def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
data_train.times = new_repeat


def prepare_for_testing(config: Config, dataset: DatasetEntity) -> Config:
@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def prepare_for_testing(config: Union[Config, ConfigDict], dataset: DatasetEntity) -> Config:
config = copy.deepcopy(config)
# FIXME. Should working directories be modified here?
config.data.test.ote_dataset = dataset
return config


def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_dataset: DatasetEntity,
@check_input_parameters_type({"train_dataset": DatasetParamTypeCheck,
"val_dataset": DatasetParamTypeCheck})
def prepare_for_training(config: Union[Config, ConfigDict], train_dataset: DatasetEntity, val_dataset: DatasetEntity,
time_monitor: TimeMonitorCallback, learning_curves: defaultdict) -> Config:
config = copy.deepcopy(config)
prepare_work_dir(config)
Expand All @@ -175,7 +187,8 @@ def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_datas
return config


def config_to_string(config: Config) -> str:
@check_input_parameters_type()
def config_to_string(config: Union[Config, ConfigDict]) -> str:
"""
Convert a full mmdetection config to a string.

Expand All @@ -194,6 +207,7 @@ def config_to_string(config: Config) -> str:
return Config(config_copy).pretty_text


@check_input_parameters_type()
def config_from_string(config_string: str) -> Config:
"""
Generate an mmdetection config dict object from a string.
Expand All @@ -207,6 +221,7 @@ def config_from_string(config_string: str) -> Config:
return Config.fromfile(temp_file.name)


@check_input_parameters_type()
def save_config_to_file(config: Config):
""" Dump the full config to a file. Filename is 'config.py', it is saved in the current work_dir. """
filepath = os.path.join(config.work_dir, 'config.py')
Expand All @@ -215,7 +230,8 @@ def save_config_to_file(config: Config):
f.write(config_string)


def prepare_work_dir(config: Config) -> str:
@check_input_parameters_type()
def prepare_work_dir(config: Union[Config, ConfigDict]) -> str:
base_work_dir = config.work_dir
checkpoint_dirs = glob.glob(os.path.join(base_work_dir, "checkpoints_round_*"))
train_round_checkpoint_dir = os.path.join(base_work_dir, f"checkpoints_round_{len(checkpoint_dirs)}")
Expand All @@ -230,6 +246,7 @@ def prepare_work_dir(config: Config) -> str:
return train_round_checkpoint_dir


@check_input_parameters_type()
def set_data_classes(config: Config, labels: List[LabelEntity]):
# Save labels in data configs.
for subset in ('train', 'val', 'test'):
Expand All @@ -256,7 +273,8 @@ def set_data_classes(config: Config, labels: List[LabelEntity]):
# self.config.model.CLASSES = label_names


def patch_datasets(config: Config, domain):
@check_input_parameters_type()
def patch_datasets(config: Config, domain: Domain):

def patch_color_conversion(pipeline):
# Default data format for OTE is RGB, while mmdet uses BGR, so negate the color conversion flag.
Expand Down Expand Up @@ -289,7 +307,8 @@ def patch_color_conversion(pipeline):
patch_color_conversion(cfg.pipeline)


def remove_from_config(config, key: str):
@check_input_parameters_type()
def remove_from_config(config: Union[Config, ConfigDict], key: str):
if key in config:
if isinstance(config, Config):
del config._cfg_dict[key]
Expand All @@ -298,6 +317,8 @@ def remove_from_config(config, key: str):
else:
raise ValueError(f'Unknown config type {type(config)}')


@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector):
if not kmeans_import:
raise ImportError('Sklearn package is not installed. To enable anchor boxes clustering, please install '
Expand All @@ -308,7 +329,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
if transforms.type == 'MultiScaleFlipAug']
prev_generator = config.model.bbox_head.anchor_generator
group_as = [len(width) for width in prev_generator.widths]
wh_stats = get_sizes_from_dataset_entity(dataset, target_wh)
wh_stats = get_sizes_from_dataset_entity(dataset, list(target_wh))

if len(wh_stats) < sum(group_as):
logger.warning(f'There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be '
Expand All @@ -332,7 +353,8 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
return config, model


def get_data_cfg(config: Config, subset: str = 'train') -> Config:
@check_input_parameters_type()
def get_data_cfg(config: Union[Config, ConfigDict], subset: str = 'train') -> Config:
data_cfg = config.data[subset]
while 'dataset' in data_cfg:
data_cfg = data_cfg.dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.unload_interface import IUnload
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

from mmdet.apis import export_model
from detection_tasks.apis.detection.config_utils import patch_config, prepare_for_testing, set_hyperparams
Expand All @@ -63,6 +67,7 @@ class OTEDetectionInferenceTask(IInferenceTask, IExportTask, IEvaluationTask, IU

_task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for inference object detection models using OTEDetection.
Expand Down Expand Up @@ -239,6 +244,7 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, confidence_th
dataset_item.append_metadata_item(active_score, model=self._task_environment.model)


@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
""" Analyzes a dataset using the latest inference model. """

Expand Down Expand Up @@ -330,7 +336,7 @@ def dummy_dump_features_hook(mod, inp, out):
eval_predictions = zip(eval_predictions, feature_vectors)
return eval_predictions, metric


@check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
Expand Down Expand Up @@ -375,6 +381,7 @@ def unload(self):
logger.warning(f"Done unloading. "
f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory")

@check_input_parameters_type()
def export(self,
export_type: ExportType,
output_model: ModelEntity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
from ote_sdk.usecases.tasks.interfaces.export_interface import ExportType
from ote_sdk.usecases.tasks.interfaces.optimization_interface import IOptimizationTask
from ote_sdk.usecases.tasks.interfaces.optimization_interface import OptimizationType
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

from mmdet.apis import train_detector
from mmdet.apis.fake_input import get_fake_input
Expand All @@ -59,6 +63,7 @@

class OTEDetectionNNCFTask(OTEDetectionInferenceTask, IOptimizationTask):

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for compressing object detection models using NNCF.
Expand Down Expand Up @@ -177,12 +182,13 @@ def _create_compressed_model(self, dataset, config):
get_fake_input_func=get_fake_input,
is_accuracy_aware=is_acc_aware_training_set)

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(
self,
optimization_type: OptimizationType,
dataset: DatasetEntity,
output_model: ModelEntity,
optimization_parameters: Optional[OptimizationParameters],
optimization_parameters: Optional[OptimizationParameters] = None,
):
if optimization_type is not OptimizationType.NNCF:
raise RuntimeError("NNCF is the only supported optimization")
Expand Down Expand Up @@ -247,6 +253,7 @@ def optimize(

self._is_training = False

@check_input_parameters_type()
def export(self, export_type: ExportType, output_model: ModelEntity):
if self._compression_ctrl is None:
super().export(export_type, output_model)
Expand All @@ -256,6 +263,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity):
super().export(export_type, output_model)
self._model.enable_dynamic_graph_building()

@check_input_parameters_type()
def save_model(self, output_model: ModelEntity):
buffer = io.BytesIO()
hyperparams = self._task_environment.get_hyper_parameters(OTEDetectionConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,19 @@
from ote_sdk.usecases.exportable_code.inference import BaseInferencer
from ote_sdk.usecases.exportable_code.prediction_to_annotation_converter import (
DetectionBoxToAnnotationConverter,
IPredictionToAnnotationConverter,
MaskToAnnotationConverter,
RotatedRectToAnnotationConverter,
)
from ote_sdk.usecases.tasks.interfaces.deployment_interface import IDeploymentTask
from ote_sdk.usecases.tasks.interfaces.evaluate_interface import IEvaluationTask
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.optimization_interface import IOptimizationTask, OptimizationType
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)
from shutil import copyfile, copytree
from typing import Any, Dict, List, Optional, Tuple, Union
from zipfile import ZipFile

Expand All @@ -66,24 +72,29 @@

class BaseInferencerWithConverter(BaseInferencer):

def __init__(self, configuration, model, converter) -> None:
@check_input_parameters_type()
def __init__(self, configuration: dict, model: Model, converter: IPredictionToAnnotationConverter) -> None:
self.configuration = configuration
self.model = model
self.converter = converter

@check_input_parameters_type()
def pre_process(self, image: np.ndarray) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
return self.model.preprocess(image)

@check_input_parameters_type()
def post_process(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, Any]) -> AnnotationSceneEntity:
detections = self.model.postprocess(prediction, metadata)

return self.converter.convert_to_annotation(detections, metadata)

@check_input_parameters_type()
def forward(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
return self.model.infer_sync(inputs)


class OpenVINODetectionInferencer(BaseInferencerWithConverter):
@check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
Expand Down Expand Up @@ -115,6 +126,7 @@ def __init__(


class OpenVINOMaskInferencer(BaseInferencerWithConverter):
@check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
Expand Down Expand Up @@ -149,6 +161,7 @@ def __init__(


class OpenVINORotatedRectInferencer(BaseInferencerWithConverter):
@check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
Expand Down Expand Up @@ -183,11 +196,13 @@ def __init__(


class OTEOpenVinoDataLoader(DataLoader):
@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def __init__(self, dataset: DatasetEntity, inferencer: BaseInferencer):
self.dataset = dataset
self.inferencer = inferencer

def __getitem__(self, index):
@check_input_parameters_type()
def __getitem__(self, index: int):
image = self.dataset[index].numpy
annotation = self.dataset[index].annotation_scene
inputs, metadata = self.inferencer.pre_process(image)
Expand All @@ -199,6 +214,7 @@ def __len__(self):


class OpenVINODetectionTask(IDeploymentTask, IInferenceTask, IEvaluationTask, IOptimizationTask):
@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
logger.info('Loading OpenVINO OTEDetectionTask')
self.task_environment = task_environment
Expand Down Expand Up @@ -230,6 +246,7 @@ def load_inferencer(self) -> Union[OpenVINODetectionInferencer, OpenVINOMaskInfe
return OpenVINORotatedRectInferencer(*args)
raise RuntimeError(f"Unknown OpenVINO Inferencer TaskType: {self.task_type}")

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
logger.info('Start OpenVINO inference')
update_progress_callback = default_progress_callback
Expand All @@ -243,6 +260,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: Optional[Inference
logger.info('OpenVINO inference completed')
return dataset

@check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
Expand All @@ -252,6 +270,7 @@ def evaluate(self,
output_result_set.performance = MetricsHelper.compute_f_measure(output_result_set).get_performance()
logger.info('OpenVINO metric evaluation completed')

@check_input_parameters_type()
def deploy(self,
output_model: ModelEntity) -> None:
logger.info('Deploying the model')
Expand Down Expand Up @@ -279,11 +298,12 @@ def deploy(self,
output_model.exportable_code = zip_buffer.getvalue()
logger.info('Deploying completed')

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(self,
optimization_type: OptimizationType,
dataset: DatasetEntity,
output_model: ModelEntity,
optimization_parameters: Optional[OptimizationParameters]):
optimization_parameters: Optional[OptimizationParameters] = None):
logger.info('Start POT optimization')

if optimization_type is not OptimizationType.POT:
Expand Down
Loading