Skip to content

Commit

Permalink
Merge pull request #919 from openvinotoolkit/es/hpo_segmentation
Browse files Browse the repository at this point in the history
[HPO] enable HPO with segmentation
  • Loading branch information
Ilya-Krylov authored Feb 22, 2022
2 parents e8d24d0 + 14c1d94 commit c1c3753
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
34 changes: 32 additions & 2 deletions ote_cli/ote_cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def check_hpopt_available():
def run_hpo(args, environment, dataset, task_type):
"""Update the environment with better hyper-parameters found by HPO"""
if check_hpopt_available():
if task_type not in {TaskType.CLASSIFICATION, TaskType.DETECTION}:
if task_type not in {
TaskType.CLASSIFICATION,
TaskType.DETECTION,
TaskType.SEGMENTATION,
}:
print(
"Currently supported task types are classification and detection."
f"{task_type} is not supported yet."
Expand Down Expand Up @@ -135,8 +139,33 @@ def run_hpo_trainer(
# set epoch
if task_type == TaskType.CLASSIFICATION:
(hyper_parameters.learning_parameters.max_num_epochs) = hp_config["iterations"]
elif task_type in (TaskType.DETECTION, TaskType.SEGMENTATION):
elif task_type == TaskType.DETECTION:
hyper_parameters.learning_parameters.num_iters = hp_config["iterations"]
elif task_type == TaskType.SEGMENTATION:
eph_comp = [
hyper_parameters.learning_parameters.learning_rate_fixed_iters,
hyper_parameters.learning_parameters.learning_rate_warmup_iters,
hyper_parameters.learning_parameters.num_iters,
]

eph_comp = list(
map(lambda x: x * hp_config["iterations"] / sum(eph_comp), eph_comp)
)

for val in sorted(
list(range(len(eph_comp))),
key=lambda k: eph_comp[k] - int(eph_comp[k]),
reverse=True,
)[: hp_config["iterations"] - sum(map(int, eph_comp))]:
eph_comp[val] += 1

hyper_parameters.learning_parameters.learning_rate_fixed_iters = int(
eph_comp[0]
)
hyper_parameters.learning_parameters.learning_rate_warmup_iters = int(
eph_comp[1]
)
hyper_parameters.learning_parameters.num_iters = int(eph_comp[2])

# set hyper-parameters and print them
HpoManager.set_hyperparameter(hyper_parameters, hp_config["params"])
Expand Down Expand Up @@ -630,6 +659,7 @@ def find_class(self, module_name, class_name):
def main():
"""Run run_hpo_trainer with a pickle file"""
hp_config = None
sys.path[0] = "" # to prevent importing nncf from this directory

try:
with open(sys.argv[1], "rb") as pfile:
Expand Down
6 changes: 6 additions & 0 deletions tests/ote_cli/test_ote_cli_tools_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ote_eval_deployment_testing,
ote_eval_openvino_testing,
ote_eval_testing,
ote_hpo_testing,
ote_train_testing,
ote_export_testing,
pot_optimize_testing,
Expand Down Expand Up @@ -118,6 +119,11 @@ def test_ote_eval_deployment(self, template):
def test_ote_demo_deployment(self, template):
ote_demo_deployment_testing(template, root, ote_dir, args)

@e2e_pytest_component
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_ote_hpo(self, template):
ote_hpo_testing(template, root, ote_dir, args)

@e2e_pytest_component
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_nncf_optimize(self, template):
Expand Down

0 comments on commit c1c3753

Please sign in to comment.