Skip to content

Commit

Permalink
Merge pull request #1122 from openvinotoolkit/CVS-85318-nncf_precision
Browse files Browse the repository at this point in the history
CVS-85318 nncf precision
  • Loading branch information
goodsong81 authored Jun 3, 2022
2 parents cc11743 + 561daf0 commit f5d8097
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, task_environment: TaskEnvironment):

# Set default model attributes.
self._optimization_methods = []
self._precision = [ModelPrecision.FP16] if self._config.get('fp16', None) else [ModelPrecision.FP32]
self._precision = self._precision_from_config
self._optimization_type = ModelOptimizationType.MO

# Create and initialize PyTorch model.
Expand All @@ -113,6 +113,10 @@ def __init__(self, task_environment: TaskEnvironment):
self._should_stop = False
logger.info('Task initialization completed')

@property
def _precision_from_config(self):
return [ModelPrecision.FP16] if self._config.get('fp16', None) else [ModelPrecision.FP32]

@property
def _hyperparams(self):
return self._task_environment.get_hyper_parameters(OTEDetectionConfig)
Expand Down Expand Up @@ -402,7 +406,8 @@ def export(self,
model = self._model.cpu()
pruning_transformation = OptimizationMethod.FILTER_PRUNING in self._optimization_methods
export_model(model, self._config, tempdir, target='openvino',
pruning_transformation=pruning_transformation, precision=self._precision[0].name)
pruning_transformation=pruning_transformation,
precision=self._precision_from_config[0].name)
bin_file = [f for f in os.listdir(tempdir) if f.endswith('.bin')][0]
xml_file = [f for f in os.listdir(tempdir) if f.endswith('.xml')][0]
with open(os.path.join(tempdir, bin_file), "rb") as f:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ def _set_attributes_by_hyperparams(self):
if quantization and pruning:
self._nncf_preset = "nncf_quantization_pruning"
self._optimization_methods = [OptimizationMethod.QUANTIZATION, OptimizationMethod.FILTER_PRUNING]
self._nncf_precision = [ModelPrecision.INT8]
self._precision = [ModelPrecision.INT8]
return
if quantization and not pruning:
self._nncf_preset = "nncf_quantization"
self._optimization_methods = [OptimizationMethod.QUANTIZATION]
self._nncf_precision = [ModelPrecision.INT8]
self._precision = [ModelPrecision.INT8]
return
if not quantization and pruning:
self._nncf_preset = "nncf_pruning"
self._optimization_methods = [OptimizationMethod.FILTER_PRUNING]
self._nncf_precision = [ModelPrecision.FP32]
self._precision = self._precision_from_config
return
raise RuntimeError('Not selected optimization algorithm')

Expand Down Expand Up @@ -249,7 +249,7 @@ def optimize(
output_model.model_format = ModelFormat.BASE_FRAMEWORK
output_model.optimization_type = self._optimization_type
output_model.optimization_methods = self._optimization_methods
output_model.precision = self._nncf_precision
output_model.precision = self._precision

self._is_training = False

Expand Down
10 changes: 9 additions & 1 deletion external/mmdetection/tests/test_ote_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ote_sdk.entities.inference_parameters import InferenceParameters
from ote_sdk.entities.model_template import TaskType, task_type_to_label_domain
from ote_sdk.entities.metrics import Performance
from ote_sdk.entities.model import ModelEntity, ModelFormat, ModelOptimizationType
from ote_sdk.entities.model import ModelEntity, ModelFormat, ModelOptimizationType, ModelPrecision
from ote_sdk.entities.model_template import parse_model_template
from ote_sdk.entities.optimization_parameters import OptimizationParameters
from ote_sdk.entities.resultset import ResultSetEntity
Expand Down Expand Up @@ -520,6 +520,14 @@ def end_to_end(
print(f'Performance of NNCF model: {nncf_performance.score.value:.4f}')
self.check_threshold(validation_performance, nncf_performance, nncf_perf_delta_tolerance,
'Too big performance difference after NNCF optimization.')

# Check whether optimize & export assigns correct model precision
nncf_task.export(ExportType.OPENVINO, nncf_model)

if nncf_task._hyperparams.nncf_optimization.enable_quantization:
assert nncf_model.precision[0] == ModelPrecision.INT8
else:
assert nncf_model.precision[0] == nncf_task._precision_from_config[0]
else:
print('Skipped test of OTEDetectionNNCFTask. Required NNCF module.')

Expand Down

0 comments on commit f5d8097

Please sign in to comment.