Skip to content

Commit

Permalink
Merge pull request #1084 from openvinotoolkit/class-incr-learning-val…
Browse files Browse the repository at this point in the history
…idation

Add model-preparation-algorithm API & training tests
  • Loading branch information
goodsong81 authored May 4, 2022
2 parents 31d3fdf + 4dc5d3a commit 18efbe5
Show file tree
Hide file tree
Showing 16 changed files with 1,579 additions and 63 deletions.
3 changes: 3 additions & 0 deletions external/mmsegmentation/tests/test_ote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,8 @@ def test(self,
test_parameters,
test_case_fx, data_collector_fx,
cur_test_expected_metrics_callback_fx):
if "18_OCR" in test_parameters["model_name"] \
or "x-mod3_OCR" in test_parameters["model_name"]:
pytest.skip("Known issue CVS-83781")
test_case_fx.run_stage(test_parameters['test_stage'], data_collector_fx,
cur_test_expected_metrics_callback_fx)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.inference_parameters import InferenceParameters, default_progress_callback
from ote_sdk.entities.train_parameters import default_progress_callback as train_default_progress_callback
from ote_sdk.entities.model import ModelEntity, ModelPrecision # ModelStatus
from ote_sdk.entities.resultset import ResultSetEntity
from mmcv.utils import ConfigDict
Expand Down Expand Up @@ -217,6 +218,7 @@ def cancel_training(self):
The stopping mechanism allows stopping after each iteration, but validation will still be carried out. Stopping
will therefore take some time.
"""
self._should_stop = True
logger.info("Cancel training requested.")
if self.cancel_interface is not None:
self.cancel_interface.cancel()
Expand All @@ -229,17 +231,34 @@ def train(self,
output_model: ModelEntity,
train_parameters: Optional[TrainParameters] = None):
logger.info('train()')
# Check for stop signal between pre-eval and training.
# If training is cancelled at this point,
if self._should_stop:
logger.info('Training cancelled.')
self._should_stop = False
self._is_training = False
return

# Set OTE LoggerHook & Time Monitor
update_progress_callback = default_progress_callback
update_progress_callback = train_default_progress_callback
if train_parameters is not None:
update_progress_callback = train_parameters.update_progress
self._time_monitor = TrainingProgressCallback(update_progress_callback)
self._learning_curves = defaultdict(OTELoggerHook.Curve)

stage_module = 'ClsTrainer'
self._data_cfg = self._init_train_data_cfg(dataset)
self._is_training = True
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)

# Check for stop signal between pre-eval and training.
# If training is cancelled at this point,
if self._should_stop:
logger.info('Training cancelled.')
self._should_stop = False
self._is_training = False
return

# get output model
model_ckpt = results.get('final_ckpt')
if model_ckpt is None:
Expand All @@ -257,6 +276,7 @@ def train(self,
dashboard_metrics=training_metrics)
logger.info(f'Final model performance: {str(performance)}')
output_model.performance = performance
self._is_training = False
logger.info('train done.')

def _init_train_data_cfg(self, dataset: DatasetEntity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from mmcv.utils import ConfigDict
from detection_tasks.apis.detection.config_utils import remove_from_config
from detection_tasks.apis.detection.ote_utils import TrainingProgressCallback
from detection_tasks.apis.detection.ote_utils import TrainingProgressCallback, InferenceProgressCallback
from detection_tasks.extension.utils.hooks import OTELoggerHook
from mpa_tasks.apis import BaseTask, TrainType
from mpa_tasks.apis.detection import DetectionConfig
Expand Down Expand Up @@ -67,6 +67,11 @@ def infer(self,
) -> DatasetEntity:
logger.info('infer()')

update_progress_callback = default_progress_callback
if inference_parameters is not None:
update_progress_callback = inference_parameters.update_progress

self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)
# If confidence threshold is adaptive then up-to-date value should be stored in the model
# and should not be changed during inference. Otherwise user-specified value should be taken.
if not self._hyperparams.postprocessing.result_based_confidence_threshold:
Expand All @@ -75,7 +80,7 @@ def infer(self,

stage_module = 'DetectionInferrer'
self._data_cfg = self._init_test_data_cfg(dataset)
results = self._run_task(stage_module, mode='train', dataset=dataset)
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=inference_parameters)
# TODO: InferenceProgressCallback register
logger.debug(f'result of run_task {stage_module} module = {results}')
output = results['outputs']
Expand Down Expand Up @@ -310,7 +315,7 @@ def cancel_training(self):
will therefore take some time.
"""
logger.info("Cancel training requested.")
# self._should_stop = True
self._should_stop = True
# stop_training_filepath = os.path.join(self._training_work_dir, '.stop_training')
# open(stop_training_filepath, 'a').close()
if self.cancel_interface is not None:
Expand All @@ -324,6 +329,14 @@ def train(self,
output_model: ModelEntity,
train_parameters: Optional[TrainParameters] = None):
logger.info('train()')
# Check for stop signal when training has stopped.
# If should_stop is true, training was cancelled and no new
if self._should_stop:
logger.info('Training cancelled.')
self._should_stop = False
self._is_training = False
return

# Set OTE LoggerHook & Time Monitor
update_progress_callback = default_progress_callback
if train_parameters is not None:
Expand All @@ -333,8 +346,15 @@ def train(self,

stage_module = 'DetectionTrainer'
self._data_cfg = self._init_train_data_cfg(dataset)
self._is_training = True
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)
# logger.info(f'result of run_task {stage_module} module = {results}')

# Check for stop signal when training has stopped. If should_stop is true, training was cancelled and no new
if self._should_stop:
logger.info('Training cancelled.')
self._should_stop = False
self._is_training = False
return

# get output model
model_ckpt = results.get('final_ckpt')
Expand Down Expand Up @@ -389,6 +409,7 @@ def train(self,
self.save_model(output_model)
output_model.performance = performance
# output_model.model_status = ModelStatus.SUCCESS
self._is_training = False
logger.info('train done.')

def _init_train_data_cfg(self, dataset: DatasetEntity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from mmcv.utils import ConfigDict
from segmentation_tasks.apis.segmentation.config_utils import remove_from_config
from segmentation_tasks.apis.segmentation.ote_utils import TrainingProgressCallback
from segmentation_tasks.apis.segmentation.ote_utils import TrainingProgressCallback, InferenceProgressCallback
from segmentation_tasks.extension.utils.hooks import OTELoggerHook
from mpa import MPAConstants
from mpa_tasks.apis import BaseTask, TrainType
Expand All @@ -22,6 +22,7 @@
from ote_sdk.configuration.helper.utils import ids_to_strings
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.inference_parameters import InferenceParameters
from ote_sdk.entities.inference_parameters import default_progress_callback as default_infer_progress_callback
from ote_sdk.entities.label import Domain
from ote_sdk.entities.metrics import (CurveMetric, InfoMetric, LineChartInfo,
MetricsGroup, Performance, ScoreMetric,
Expand All @@ -48,8 +49,6 @@
create_annotation_from_segmentation_map,
create_hard_prediction_from_soft_prediction)

# from mmdet.apis import export_model


logger = get_logger()

Expand All @@ -70,12 +69,14 @@ def infer(self,
logger.info('infer()')

if inference_parameters is not None:
# update_progress_callback = inference_parameters.update_progress
update_progress_callback = inference_parameters.update_progress
is_evaluation = inference_parameters.is_evaluation
else:
# update_progress_callback = default_infer_progress_callback
update_progress_callback = default_infer_progress_callback
is_evaluation = False

self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)

stage_module = 'SegInferrer'
self._data_cfg = self._init_test_data_cfg(dataset)
self._label_dictionary = dict(enumerate(self._labels, 1))
Expand Down Expand Up @@ -187,8 +188,10 @@ def _init_test_data_cfg(self, dataset: DatasetEntity):
data_cfg = ConfigDict(
data=ConfigDict(
train=ConfigDict(
ote_dataset=None,
labels=self._labels,
dataset=ConfigDict(
ote_dataset=None,
labels=self._labels,
)
),
test=ConfigDict(
ote_dataset=dataset,
Expand Down Expand Up @@ -311,7 +314,7 @@ def cancel_training(self):
will therefore take some time.
"""
logger.info("Cancel training requested.")
# self._should_stop = True
self._should_stop = True
# stop_training_filepath = os.path.join(self._training_work_dir, '.stop_training')
# open(stop_training_filepath, 'a').close()
if self.cancel_interface is not None:
Expand All @@ -325,6 +328,14 @@ def train(self,
output_model: ModelEntity,
train_parameters: Optional[TrainParameters] = None):
logger.info('train()')
# Check for stop signal between pre-eval and training.
# If training is cancelled at this point,
if self._should_stop:
logger.info('Training cancelled.')
self._should_stop = False
self._is_training = False
return

# Set OTE LoggerHook & Time Monitor
if train_parameters is not None:
update_progress_callback = train_parameters.update_progress
Expand All @@ -336,8 +347,17 @@ def train(self,
# learning_curves = defaultdict(OTELoggerHook.Curve)
stage_module = 'SegTrainer'
self._data_cfg = self._init_train_data_cfg(dataset)
self._is_training = True
results = self._run_task(stage_module, mode='train', dataset=dataset, parameters=train_parameters)

# Check for stop signal when training has stopped.
# If should_stop is true, training was cancelled and no new
if self._should_stop:
logger.info('Training cancelled.')
self._should_stop = False
self._is_training = False
return

# get output model
model_ckpt = results.get('final_ckpt')
if model_ckpt is None:
Expand All @@ -358,6 +378,7 @@ def train(self,
self.save_model(output_model)
output_model.performance = performance
# output_model.model_status = ModelStatus.SUCCESS
self._is_training = False
logger.info('train done.')

def _init_train_data_cfg(self, dataset: DatasetEntity):
Expand Down
54 changes: 11 additions & 43 deletions external/model-preparation-algorithm/mpa_tasks/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,6 @@
logger = get_logger()


class _MPAUpdateProgressCallbackWrapper(UpdateProgressCallback):
""" UpdateProgressCallback wrapper
just wrapping the callback instance and provides error free representation as 'pretty_text'
"""

def __init__(self, callback, **kwargs):
if not callable(callback):
raise RuntimeError(f'cannot accept a not callable object!! {callback}')
self._callback = callback
super().__init__(**kwargs)

def __repr__(self):
return f"'{__name__}._MPAUpdateProgressCallbackWrapper'"

def __reduce__(self):
return (self.__class__, (id(self),))

def __call__(self, progress: float, score: Optional[float] = None):
self._callback(progress, score)


class BaseTask:
def __init__(self, task_config, task_environment: TaskEnvironment):
self._task_config = task_config
Expand Down Expand Up @@ -83,6 +62,8 @@ def __init__(self, task_config, task_environment: TaskEnvironment):
self._mode = None
self._time_monitor = None
self._learning_curves = None
self._is_training = False
self._should_stop = False
self.cancel_interface = None
self.reserved_cancel = False
self.on_hook_initialized = self.OnHookInitialized(self)
Expand All @@ -104,30 +85,9 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw
raise RuntimeError(
"'recipe_cfg' is not initialized yet."
"call prepare() method before calling this method")
# self._stage_module = stage_module

if mode is not None:
self._mode = mode
if parameters is not None:
if isinstance(parameters, TrainParameters):
hook_name = 'TrainProgressUpdateHook'
progress_callback = _MPAUpdateProgressCallbackWrapper(parameters.update_progress)
# TODO: update recipe to do RESUME
if parameters.resume:
pass
elif isinstance(parameters, InferenceParameters):
hook_name = 'InferenceProgressUpdateHook'
progress_callback = _MPAUpdateProgressCallbackWrapper(parameters.update_progress)
else:
hook_name = 'ProgressUpdateHook'
progress_callback = None
logger.info(f'progress callback = {progress_callback}, hook name = {hook_name}')
if progress_callback is not None:
progress_update_hook_cfg = ConfigDict(
type='ProgressUpdateHook',
name=hook_name,
callback=progress_callback
)
update_or_add_custom_hook(self._recipe_cfg, progress_update_hook_cfg)

common_cfg = ConfigDict(dict(output_path=self._output_path))

Expand All @@ -152,6 +112,14 @@ def finalize(self):
if os.path.exists(self._output_path):
shutil.rmtree(self._output_path, ignore_errors=False)

def _delete_scratch_space(self):
"""
Remove model checkpoints and mpa logs
"""

if os.path.exists(self._output_path):
shutil.rmtree(self._output_path, ignore_errors=False)

def __del__(self):
self.finalize()

Expand Down
2 changes: 1 addition & 1 deletion external/model-preparation-algorithm/ote_tests_pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[pytest]
python_files = test_*_cls_il.py
python_files = test_ote_*.py
11 changes: 11 additions & 0 deletions external/model-preparation-algorithm/tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

try:
import os
from e2e import config as config_e2e

config_e2e.repository_name = os.environ.get('TT_REPOSITORY_NAME', 'ote/training_extensions/external/model-preparation-algorithm')
except ImportError:
pass
Loading

0 comments on commit 18efbe5

Please sign in to comment.