Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 02d724a
Merge: 56e6624 3ba3272
Author: Songki Choi <songki.choi@intel.com>
Date:   Tue May 31 16:00:39 2022 +0900

    Merge pull request #1113 from openvinotoolkit/da/anomaly-exportable-code-fix

    [ANOMALY] Fix: exportable code for anomaly tasks

commit 56e6624
Merge: 2e18117 9ef77d0
Author: Songki Choi <songki.choi@intel.com>
Date:   Tue May 31 15:45:59 2022 +0900

    Merge pull request #1118 from openvinotoolkit/ashwin/fix_non_deterministic

    [Anomaly Task] Fix non deterministic + sample.py

commit 2e18117
Merge: c240c4b 16f2138
Author: Songki Choi <songki.choi@intel.com>
Date:   Tue May 31 15:40:43 2022 +0900

    Merge pull request #1120 from openvinotoolkit/da/use-is-anomalous

    [ANOMALY] Use is_anomalous attribute instead of string matching

commit c240c4b
Merge: f003a20 46bf1bc
Author: Songki Choi <songki.choi@intel.com>
Date:   Tue May 31 15:33:42 2022 +0900

    Merge pull request #1117 from wonjuleee/develop

    nncf versionn upgrade

commit 16f2138
Author: Dick Ameln <dick.ameln@intel.com>
Date:   Mon May 30 11:35:18 2022 +0200

    add is_anomalous attribute to labels in test cases

commit f003a20
Merge: 38a4d88 bdb7599
Author: Songki Choi <songki.choi@intel.com>
Date:   Wed May 25 15:39:02 2022 +0900

    Merge pull request #1112 from openvinotoolkit/vsaltykovx/add_mmdetection_input_parameters_validation_2

    Vsaltykovx/add mmdetection input parameters validation 2

commit 2aab8df
Author: Dick Ameln <dick.ameln@intel.com>
Date:   Tue May 24 17:56:52 2022 +0200

    use is_anomalous attribute instead of string matching

commit bdb7599
Author: Slawomir Strehlke <slawomir.strehlke@intel.com>
Date:   Tue May 24 10:21:37 2022 +0200

    Corrected config input parameter type.

commit 9ef77d0
Author: Ashwin Vaidya <ashwin.vaidya@intel.com>
Date:   Mon May 23 16:01:37 2022 +0200

    Set optional params to None

commit fdd5307
Author: Ashwin Vaidya <ashwin.vaidya@intel.com>
Date:   Mon May 23 15:39:37 2022 +0200

    spacing

commit b7c30a2
Author: Ashwin Vaidya <ashwin.vaidya@intel.com>
Date:   Mon May 23 15:38:59 2022 +0200

    Add seed + sample.py guide

commit 46bf1bc
Author: Wonju Lee <wonju.lee@intel.com>
Date:   Mon May 23 19:10:47 2022 +0900

    nncf versionn upgrade

commit 38a4d88
Merge: c49a6fd cf33dcb
Author: Songki Choi <songki.choi@intel.com>
Date:   Mon May 23 18:20:43 2022 +0900

    Merge pull request #1114 from A-Artemis/develop

    Added a getter and a setter for __metadata

commit cf33dcb
Author: Aurelien <aurelien.adriaenssens@intel.com>
Date:   Mon May 23 10:20:33 2022 +0200

    Corrected the usage of .metadata or .get_metadata() in certain tests. In test_metadata.py I have updated the assertion to "typing.Optional[str]".

commit fcf4925
Author: Aurelien <aurelien.adriaenssens@intel.com>
Date:   Thu May 19 16:42:18 2022 +0200

    Added a getter and a setter for __metadata of DatasetItemEntity. Updated the tests as well.

commit 3ba3272
Author: Dick Ameln <dick.ameln@intel.com>
Date:   Wed May 18 17:08:17 2022 +0200

    use standard visualizer instead of anomaly visualizer for anomaly tasks

commit 95440ea
Merge: 808a0ed c49a6fd
Author: hlewando <Killer_21212>
Date:   Wed May 18 13:10:30 2022 +0200

    Merge remote-tracking branch 'origin/develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 808a0ed
Merge: 3ac55ba c54250b
Author: Hubert Lewandowski <hubert.lewandowski@intel.com>
Date:   Tue May 17 08:56:58 2022 +0200

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 3ac55ba
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Apr 5 13:58:01 2022 +0300

    updated check_input_parameters_type

commit c99aebd
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Apr 5 13:46:33 2022 +0300

    added tests to cover get_data_cfg function and StopLossNanTrainingHook after_train_iter method

commit 6a7df1c
Merge: 31e179c 33d11e6
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Apr 5 13:02:21 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

    # Conflicts:
    #	external/mmdetection/detection_tasks/apis/detection/openvino_task.py

commit 31e179c
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Mar 28 10:21:05 2022 +0300

    updated check_nested_classes_parameters function

commit cd31623
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Mar 25 13:44:21 2022 +0300

    fixed expected type in JsonFilePathCheck

commit fb0fd0e
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Mar 25 13:34:10 2022 +0300

    updated function in ote_sdk/ote_sdk/utils/argument_checks.py

commit ede7332
Merge: cc77164 0d4ce2c
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Mar 25 11:55:08 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit cc77164
Merge: 94c496a 4129bc4
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Wed Mar 23 10:01:59 2022 +0300

    Merge remote-tracking branch 'origin/vsaltykovx/add_mmdetection_input_parameters_validation' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 94c496a
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Wed Mar 23 10:01:36 2022 +0300

    optimized imports

commit 4129bc4
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Mar 22 17:15:12 2022 +0300

    fix type for weight_file

commit 4f55cae
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Mar 22 15:11:00 2022 +0300

    updated expected types

commit 3e33ba6
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Mar 22 14:09:50 2022 +0300

    added input parameters validation in mmdet/apis/ote/apis/detection/ methods and functions

commit 8735fbb
Merge: 3783f50 1833b6c
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Mar 22 11:19:24 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 3783f50
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Mar 21 17:07:23 2022 +0300

    fixed test_load_annotation_from_ote_dataset_call_params_validation

commit 5ba02cc
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Mar 21 16:16:05 2022 +0300

    moved mmdetection params validation tests to training_extensions

commit b5f6621
Merge: e26fd37 8d3e4ab
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Mar 21 16:07:11 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit e26fd37
Merge: 04e5fb8 5931ec7
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Mar 21 16:05:36 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 04e5fb8
Merge: 1d5e06c abafaa3
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Mar 18 13:55:12 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 1d5e06c
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Mar 18 12:07:27 2022 +0300

    added check_input_parameters_type decorator

commit 6919ef8
Merge: ef9ffce a4b4263
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Wed Mar 16 11:55:58 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit ef9ffce
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Mar 1 09:03:26 2022 +0300

    removed additional log messages

commit 3fb9680
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Feb 28 17:05:51 2022 +0300

    updated logger

commit 52fd656
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Feb 28 13:19:52 2022 +0300

    added log messages to raise_value_error_if_parameter_has_unexpected_type

commit eb83c34
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Feb 28 13:18:40 2022 +0300

    added log messages to raise_value_error_if_parameter_has_unexpected_type

commit 7071e87
Merge: 30fbec2 a2ae0e0
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Feb 28 11:12:09 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 30fbec2
Merge: ab81a36 6b3302e
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Feb 25 18:24:51 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit ab81a36
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Feb 25 17:24:05 2022 +0300

    moved load_test_dataset function to tests\parameters_validation\validation_helper.py

commit 51c0a38
Merge: d1c521c 719e80a
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Thu Feb 24 09:17:49 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit d1c521c
Merge: adc9621 c1c3753
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Feb 22 09:16:13 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit adc9621
Merge: 23207fd e8d24d0
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Feb 21 09:17:14 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 23207fd
Merge: 1cc4410 70bd630
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Mon Feb 21 08:19:36 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 1cc4410
Merge: 44a6c2b 9d2503b
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Feb 18 15:08:41 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 44a6c2b
Merge: 5a34e7f d99df37
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Fri Feb 18 09:04:48 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 5a34e7f
Merge: af5b029 533ba44
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Thu Feb 17 13:05:52 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit af5b029
Merge: f9d96a7 3ebf8e9
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Thu Feb 17 09:00:44 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit f9d96a7
Merge: 0b8cbf1 461d501
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Feb 15 16:56:36 2022 +0300

    Merge branch 'develop' into vsaltykovx/add_mmdetection_input_parameters_validation

commit 0b8cbf1
Author: saltykox <valeriyx.saltykov@intel.com>
Date:   Tue Feb 15 15:20:17 2022 +0300

    added input parameters validation and tests for mmdet/apis/ote
  • Loading branch information
sstrehlk committed May 31, 2022
1 parent f95213c commit 474317e
Show file tree
Hide file tree
Showing 40 changed files with 3,422 additions and 154 deletions.
2 changes: 1 addition & 1 deletion external/anomaly/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ kornia==0.5.6
lxml==4.6.5
matplotlib==3.4.3
networkx~=2.5
nncf@ git+https://github.com/openvinotoolkit/nncf@37a830a412e60ec2fd2d84d7f00e2524e5f62777#egg=nncf
nncf==2.2.0
numpy==1.19.5
omegaconf==2.1.1
onnx==1.10.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def create_task_annotations(task: str, data_path: str, annotation_path: str) ->
Raises:
ValueError: When task is not classification, detection or segmentation.
"""
annotation_path = os.path.join(data_path, task)
annotation_path = os.path.join(annotation_path, task)
os.makedirs(annotation_path, exist_ok=True)

for split in ["train", "val", "test"]:
Expand Down
13 changes: 12 additions & 1 deletion external/anomaly/ote_anomalib/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Optional

from anomalib.utils.callbacks import MinMaxNormalizationCallback
from ote_anomalib import AnomalyInferenceTask
from ote_anomalib.callbacks import ProgressCallback
Expand All @@ -23,7 +25,7 @@
from ote_sdk.entities.model import ModelEntity
from ote_sdk.entities.train_parameters import TrainParameters
from ote_sdk.usecases.tasks.interfaces.training_interface import ITrainingTask
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, seed_everything

logger = get_logger(__name__)

Expand All @@ -36,17 +38,26 @@ def train(
dataset: DatasetEntity,
output_model: ModelEntity,
train_parameters: TrainParameters,
seed: Optional[int] = None,
) -> None:
"""Train the anomaly classification model.
Args:
dataset (DatasetEntity): Input dataset.
output_model (ModelEntity): Output model to save the model weights.
train_parameters (TrainParameters): Training parameters
seed: (Optional[int]): Setting seed to a value other than 0 also marks PytorchLightning trainer's
deterministic flag to True.
"""
logger.info("Training the model.")

config = self.get_config()

if seed:
logger.info(f"Setting seed to {seed}")
seed_everything(seed, workers=True)
config.trainer.deterministic = True

logger.info("Training Configs '%s'", config)

datamodule = OTEAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def _run_ote_training(self, data_collector):
self.copy_hyperparams = deepcopy(self.task.task_environment.get_hyper_parameters())

try:
self.task.train(self.dataset, self.output_model, TrainParameters)
# fix seed so that result is repeatable
self.task.train(self.dataset, self.output_model, TrainParameters, seed=42)
except Exception as ex:
raise RuntimeError("Training failed") from ex

Expand Down
23 changes: 23 additions & 0 deletions external/anomaly/tools/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
OpenVINO Training Extension interacts with the anomaly detection library ([Anomalib](https://github.com/openvinotoolkit/anomalib)) by providing interfaces in the `external/anomaly` of this repository. The `sample.py` file contained in this folder serves as an end-to-end example of how these interfaces are used. To begin using this script, first ensure that `ote_cli`, `ote_sdk` and `external/anomaly` dependencies are installed.

To get started, we provide a handy script in `ote_anomalib/data/create_mvtec_ad_json_annotations.py` to help generate annotation json files for MVTec dataset. Assuming that you have placed the MVTec dataset in a directory your home folder (`~/dataset/MVTec`), you can run the following command to generate the annotations.

```bash
python create_mvtec_ad_json_annotations.py --data_path ~/datasets/MVTec --annotation_path ~/training_extensions/data/MVtec/
```

This will generate three folders in `~/training_extensions/data/MVtec/` for classification, segmentation and detection task.

Then, to run sample.py you can use the following command.

```bash
python tools/sample.py \
--dataset_path ~/datasets/MVTec \
--category bottle \
--train-ann-files ../../data/MVtec/bottle/segmentation/train.json \
--val-ann-files ../../data/MVtec/bottle/segmentation/val.json \
--test-ann-files ../../data/MVtec/bottle/segmentation/test.json \
--model_template_path ./configs/anomaly_segmentation/padim/template.yaml
```

Optionally, you can also optimize to `nncf` or `pot` by using the `--optimization` flag
22 changes: 15 additions & 7 deletions external/anomaly/tools/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import shutil
from argparse import Namespace
from typing import Any, Dict, Type, Union
from typing import Any, Dict, Optional, Type, Union

from ote_anomalib import AnomalyNNCFTask, OpenVINOAnomalyTask
from ote_anomalib.data.dataset import (
Expand Down Expand Up @@ -61,13 +61,18 @@ def __init__(
val_subset: Dict[str, str],
test_subset: Dict[str, str],
model_template_path: str,
seed: Optional[int] = None,
) -> None:
"""Initialize OteAnomalyTask.
Args:
dataset_path (str): Path to the MVTec dataset.
seed (int): Seed to split the dataset into train/val/test splits.
train_subset (Dict[str, str]): Dictionary containing path to train annotation file and path to dataset.
val_subset (Dict[str, str]): Dictionary containing path to validation annotation file and path to dataset.
test_subset (Dict[str, str]): Dictionary containing path to test annotation file and path to dataset.
model_template_path (str): Path to model template.
seed (Optional[int]): Setting seed to a value other than 0 also marks PytorchLightning trainer's
deterministic flag to True.
Example:
>>> import os
Expand All @@ -78,9 +83,12 @@ def __init__(
>>> model_template_path = "./configs/anomaly_classification/padim/template.yaml"
>>> dataset_path = "./datasets/MVTec"
>>> seed = 0
>>> task = OteAnomalyTask(
... dataset_path=dataset_path, seed=seed, model_template_path=model_template_path
... dataset_path=dataset_path,
... train_subset={"ann_file": train.json, "data_root": dataset_path},
... val_subset={"ann_file": val.json, "data_root": dataset_path},
... test_subset={"ann_file": test.json, "data_root": dataset_path},
... model_template_path=model_template_path
... )
>>> task.train()
Expand Down Expand Up @@ -110,6 +118,7 @@ def __init__(
self.openvino_task: OpenVINOAnomalyTask
self.nncf_task: AnomalyNNCFTask
self.results = {"category": dataset_path}
self.seed = seed

def get_dataclass(
self,
Expand Down Expand Up @@ -176,9 +185,7 @@ def train(self) -> ModelEntity:
configuration=self.task_environment.get_model_configuration(),
)
self.torch_task.train(
dataset=self.dataset,
output_model=output_model,
train_parameters=TrainParameters(),
dataset=self.dataset, output_model=output_model, train_parameters=TrainParameters(), seed=self.seed
)

logger.info("Inferring the base torch model on the validation set.")
Expand Down Expand Up @@ -364,6 +371,7 @@ def main() -> None:
val_subset=val_subset,
test_subset=test_subset,
model_template_path=args.model_template_path,
seed=args.seed,
)

task.train()
Expand Down
2 changes: 1 addition & 1 deletion external/deep-object-reid/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
nncf @ git+https://github.com/openvinotoolkit/nncf@ed552bee19b1e40eaa2c06627acb928c1d6c2360#egg=nncf
nncf==2.2.0
openvino==2022.1.0
openvino-dev==2022.1.0
openmodelzoo-modelapi @ git+https://github.com/openvinotoolkit/open_model_zoo/@releases/2022/SCv1.1#egg=openmodelzoo-modelapi&subdirectory=demos/common/python
44 changes: 33 additions & 11 deletions external/mmdetection/detection_tasks/apis/detection/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
import os
import tempfile
from collections import defaultdict
from typing import List, Optional
from typing import List, Optional, Union

import torch
from mmcv import Config, ConfigDict
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.label import LabelEntity, Domain
from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
DirectoryPathCheck,
check_input_parameters_type
)

from detection_tasks.extension.datasets.data_utils import get_anchor_boxes, \
get_sizes_from_dataset_entity, format_list_to_str
Expand All @@ -43,14 +48,16 @@
logger = get_root_logger()


@check_input_parameters_type()
def is_epoch_based_runner(runner_config: ConfigDict):
return 'Epoch' in runner_config.type


@check_input_parameters_type({"work_dir": DirectoryPathCheck})
def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domain: Domain, random_seed: Optional[int] = None):
# Set runner if not defined.
if 'runner' not in config:
config.runner = {'type': 'EpochBasedRunner'}
config.runner = ConfigDict({'type': 'EpochBasedRunner'})

# Check that there is no conflict in specification of number of training epochs.
# Move global definition of epochs inside runner config.
Expand Down Expand Up @@ -112,6 +119,7 @@ def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domai
config.seed = random_seed


@check_input_parameters_type()
def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.optimizer.lr = float(hyperparams.learning_parameters.learning_rate)
config.lr_config.warmup_iters = int(hyperparams.learning_parameters.learning_rate_warmup_iters)
Expand All @@ -126,7 +134,8 @@ def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.runner.max_iters = total_iterations


def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
@check_input_parameters_type()
def patch_adaptive_repeat_dataset(config: Union[Config, ConfigDict], num_samples: int,
decay: float = -0.002, factor: float = 30):
""" Patch the repeat times and training epochs adatively
Expand Down Expand Up @@ -155,14 +164,17 @@ def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
data_train.times = new_repeat


def prepare_for_testing(config: Config, dataset: DatasetEntity) -> Config:
@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def prepare_for_testing(config: Union[Config, ConfigDict], dataset: DatasetEntity) -> Config:
config = copy.deepcopy(config)
# FIXME. Should working directories be modified here?
config.data.test.ote_dataset = dataset
return config


def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_dataset: DatasetEntity,
@check_input_parameters_type({"train_dataset": DatasetParamTypeCheck,
"val_dataset": DatasetParamTypeCheck})
def prepare_for_training(config: Union[Config, ConfigDict], train_dataset: DatasetEntity, val_dataset: DatasetEntity,
time_monitor: TimeMonitorCallback, learning_curves: defaultdict) -> Config:
config = copy.deepcopy(config)
prepare_work_dir(config)
Expand All @@ -175,7 +187,8 @@ def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_datas
return config


def config_to_string(config: Config) -> str:
@check_input_parameters_type()
def config_to_string(config: Union[Config, ConfigDict]) -> str:
"""
Convert a full mmdetection config to a string.
Expand All @@ -194,6 +207,7 @@ def config_to_string(config: Config) -> str:
return Config(config_copy).pretty_text


@check_input_parameters_type()
def config_from_string(config_string: str) -> Config:
"""
Generate an mmdetection config dict object from a string.
Expand All @@ -207,6 +221,7 @@ def config_from_string(config_string: str) -> Config:
return Config.fromfile(temp_file.name)


@check_input_parameters_type()
def save_config_to_file(config: Config):
""" Dump the full config to a file. Filename is 'config.py', it is saved in the current work_dir. """
filepath = os.path.join(config.work_dir, 'config.py')
Expand All @@ -215,7 +230,8 @@ def save_config_to_file(config: Config):
f.write(config_string)


def prepare_work_dir(config: Config) -> str:
@check_input_parameters_type()
def prepare_work_dir(config: Union[Config, ConfigDict]) -> str:
base_work_dir = config.work_dir
checkpoint_dirs = glob.glob(os.path.join(base_work_dir, "checkpoints_round_*"))
train_round_checkpoint_dir = os.path.join(base_work_dir, f"checkpoints_round_{len(checkpoint_dirs)}")
Expand All @@ -230,6 +246,7 @@ def prepare_work_dir(config: Config) -> str:
return train_round_checkpoint_dir


@check_input_parameters_type()
def set_data_classes(config: Config, labels: List[LabelEntity]):
# Save labels in data configs.
for subset in ('train', 'val', 'test'):
Expand All @@ -256,7 +273,8 @@ def set_data_classes(config: Config, labels: List[LabelEntity]):
# self.config.model.CLASSES = label_names


def patch_datasets(config: Config, domain):
@check_input_parameters_type()
def patch_datasets(config: Config, domain: Domain):

def patch_color_conversion(pipeline):
# Default data format for OTE is RGB, while mmdet uses BGR, so negate the color conversion flag.
Expand Down Expand Up @@ -289,7 +307,8 @@ def patch_color_conversion(pipeline):
patch_color_conversion(cfg.pipeline)


def remove_from_config(config, key: str):
@check_input_parameters_type()
def remove_from_config(config: Union[Config, ConfigDict], key: str):
if key in config:
if isinstance(config, Config):
del config._cfg_dict[key]
Expand All @@ -298,6 +317,8 @@ def remove_from_config(config, key: str):
else:
raise ValueError(f'Unknown config type {type(config)}')


@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector):
if not kmeans_import:
raise ImportError('Sklearn package is not installed. To enable anchor boxes clustering, please install '
Expand All @@ -308,7 +329,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
if transforms.type == 'MultiScaleFlipAug']
prev_generator = config.model.bbox_head.anchor_generator
group_as = [len(width) for width in prev_generator.widths]
wh_stats = get_sizes_from_dataset_entity(dataset, target_wh)
wh_stats = get_sizes_from_dataset_entity(dataset, list(target_wh))

if len(wh_stats) < sum(group_as):
logger.warning(f'There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be '
Expand All @@ -332,7 +353,8 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
return config, model


def get_data_cfg(config: Config, subset: str = 'train') -> Config:
@check_input_parameters_type()
def get_data_cfg(config: Union[Config, ConfigDict], subset: str = 'train') -> Config:
data_cfg = config.data[subset]
while 'dataset' in data_cfg:
data_cfg = data_cfg.dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.unload_interface import IUnload
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

from mmdet.apis import export_model
from detection_tasks.apis.detection.config_utils import patch_config, prepare_for_testing, set_hyperparams
Expand All @@ -63,6 +67,7 @@ class OTEDetectionInferenceTask(IInferenceTask, IExportTask, IEvaluationTask, IU

_task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for inference object detection models using OTEDetection.
Expand Down Expand Up @@ -239,6 +244,7 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, confidence_th
dataset_item.append_metadata_item(active_score, model=self._task_environment.model)


@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
""" Analyzes a dataset using the latest inference model. """

Expand Down Expand Up @@ -330,7 +336,7 @@ def dummy_dump_features_hook(mod, inp, out):
eval_predictions = zip(eval_predictions, feature_vectors)
return eval_predictions, metric


@check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
Expand Down Expand Up @@ -375,6 +381,7 @@ def unload(self):
logger.warning(f"Done unloading. "
f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory")

@check_input_parameters_type()
def export(self,
export_type: ExportType,
output_model: ModelEntity):
Expand Down
Loading

0 comments on commit 474317e

Please sign in to comment.