Skip to content

Commit

Permalink
fix nncf precision config to exported model
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjuleee committed Jun 2, 2022
1 parent eed9902 commit 0957213
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, task_environment: TaskEnvironment):

# Set default model attributes.
self._optimization_methods = []
self._precision = [ModelPrecision.FP16] if self._config.get('fp16', None) else [ModelPrecision.FP32]
self._precision = self._precision_from_config
self._optimization_type = ModelOptimizationType.MO

# Create and initialize PyTorch model.
Expand All @@ -113,6 +113,10 @@ def __init__(self, task_environment: TaskEnvironment):
self._should_stop = False
logger.info('Task initialization completed')

@property
def _precision_from_config(self):
return [ModelPrecision.FP16] if self._config.get('fp16', None) else [ModelPrecision.FP32]

@property
def _hyperparams(self):
return self._task_environment.get_hyper_parameters(OTEDetectionConfig)
Expand Down Expand Up @@ -402,7 +406,8 @@ def export(self,
model = self._model.cpu()
pruning_transformation = OptimizationMethod.FILTER_PRUNING in self._optimization_methods
export_model(model, self._config, tempdir, target='openvino',
pruning_transformation=pruning_transformation, precision=self._precision[0].name)
pruning_transformation=pruning_transformation,
precision=self._precision_from_config[0].name)
bin_file = [f for f in os.listdir(tempdir) if f.endswith('.bin')][0]
xml_file = [f for f in os.listdir(tempdir) if f.endswith('.xml')][0]
with open(os.path.join(tempdir, bin_file), "rb") as f:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ def _set_attributes_by_hyperparams(self):
if quantization and pruning:
self._nncf_preset = "nncf_quantization_pruning"
self._optimization_methods = [OptimizationMethod.QUANTIZATION, OptimizationMethod.FILTER_PRUNING]
self._nncf_precision = [ModelPrecision.INT8]
self._precision = [ModelPrecision.INT8]
return
if quantization and not pruning:
self._nncf_preset = "nncf_quantization"
self._optimization_methods = [OptimizationMethod.QUANTIZATION]
self._nncf_precision = [ModelPrecision.INT8]
self._precision = [ModelPrecision.INT8]
return
if not quantization and pruning:
self._nncf_preset = "nncf_pruning"
self._optimization_methods = [OptimizationMethod.FILTER_PRUNING]
self._nncf_precision = [ModelPrecision.FP32]
self._precision = [ModelPrecision.INT8]
return
raise RuntimeError('Not selected optimization algorithm')

Expand Down Expand Up @@ -249,7 +249,7 @@ def optimize(
output_model.model_format = ModelFormat.BASE_FRAMEWORK
output_model.optimization_type = self._optimization_type
output_model.optimization_methods = self._optimization_methods
output_model.precision = self._nncf_precision
output_model.precision = self._precision

self._is_training = False

Expand Down

0 comments on commit 0957213

Please sign in to comment.