Skip to content

Commit

Permalink
fix tests + update config names
Browse files Browse the repository at this point in the history
Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
  • Loading branch information
ashwinvaidya17 committed Mar 14, 2024
1 parent 7cb76e9 commit 11b2661
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
OTXTaskType.VISUAL_PROMPTING: "otx.core.model.entity.visual_prompting.OVVisualPromptingModel",
OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: "otx.core.model.entity.visual_prompting.OVZeroShotVisualPromptingModel",
OTXTaskType.ACTION_CLASSIFICATION: "otx.core.model.entity.action_classification.OVActionClsModel",
OTXTaskType.ANOMALY_CLASSIFICATION: "otx.algo.anomaly.anomaly_openvino.AnomalyOpenVINO",
OTXTaskType.ANOMALY_DETECTION: "otx.algo.anomaly.anomaly_openvino.AnomalyOpenVINO",
OTXTaskType.ANOMALY_SEGMENTATION: "otx.algo.anomaly.anomaly_openvino.AnomalyOpenVINO",
}


Expand Down
17 changes: 11 additions & 6 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
import pytest
import yaml
from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK

from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
from tests.integration.cli.utils import run_main


Expand Down Expand Up @@ -129,9 +129,6 @@ def test_otx_e2e(
"dino_v2",
"instance_segmentation",
"action",
"anomaly_classification",
"anomaly_detection",
"anomaly_segmentation",
]
):
return
Expand All @@ -149,6 +146,10 @@ def test_otx_e2e(
"EXPORTABLE_CODE": "exportable_code.zip",
}

overrides = fxt_cli_override_command_per_task[task]
if "anomaly" in task:
overrides = {} # Overrides are not needed in export

tmp_path_test = tmp_path / f"otx_test_{model_name}"
for fmt in format_to_file:
command_cfg = [
Expand All @@ -160,7 +161,7 @@ def test_otx_e2e(
fxt_target_dataset_per_task[task],
"--work_dir",
str(tmp_path_test / "outputs" / fmt),
*fxt_cli_override_command_per_task[task],
*overrides,
"--checkpoint",
str(ckpt_files[-1]),
"--export_format",
Expand All @@ -185,6 +186,10 @@ def test_otx_e2e(
)
exported_model_path = str(ov_latest_dir / "exported_model.xml")

overrides = fxt_cli_override_command_per_task[task]
if "anomaly" in task:
overrides = {} # Overrides are not needed in infer

command_cfg = [
"otx",
"test",
Expand All @@ -196,7 +201,7 @@ def test_otx_e2e(
str(tmp_path_test / "outputs"),
"--engine.device",
"cpu",
*fxt_cli_override_command_per_task[task],
*overrides,
"--checkpoint",
exported_model_path,
]
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/cli/test_export_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def fxt_local_seed() -> int:


TASK_NAME_TO_MAIN_METRIC_NAME = {
"anomaly_classification": "test/accuracy",
"anomaly_segmentation": "test/accuracy",
"anomaly_detection": "test/accuracy",
"semantic_segmentation": "test/Dice",
"multi_label_cls": "test/accuracy",
"multi_class_cls": "test/accuracy",
Expand Down

0 comments on commit 11b2661

Please sign in to comment.