Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OTE_SDK] expand ModelTemplate.is_global #980

Merged
merged 2 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ote_sdk/ote_sdk/entities/model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,10 @@ def is_task_global(self) -> bool:
"""
Returns ``True`` if the task is global task i.e. if task produces global labels
"""
return self.task_type in [TaskType.CLASSIFICATION]
return self.task_type in (
TaskType.CLASSIFICATION,
TaskType.ANOMALY_CLASSIFICATION,
)


class NullModelTemplate(ModelTemplate):
Expand Down
33 changes: 21 additions & 12 deletions ote_sdk/ote_sdk/tests/entities/test_model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,27 +965,36 @@ def test_model_template_is_task_global(self):
Test passes if is_task_global method of ModelTemplate object returns expected bool values related to
task_type attribute
<b>Steps</b>
1. Check is_task_global method returns True if task_type equal to CLASSIFICATION
2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION
1. Check is_task_global method returns True if task_type equal to CLASSIFICATION or ANOMALY_CLASSIFICATION
2. Check is_task_global method returns False if task_type not equal to CLASSIFICATION or ANOMALY_CLASSIFICATION
"""
# Check is_task_global method returns True
default_parameters = self.default_model_parameters()
task_global_parameters = dict(default_parameters)
task_global_parameters["task_type"] = TaskType.CLASSIFICATION
task_global_model_template = ModelTemplate(**task_global_parameters)
assert task_global_model_template.is_task_global()
# Check is_task_global method returns False
# Check is_task_global method returns True for CLASSIFICATION and ANOMALY_CLASSIFICATION
for global_task_type in (
TaskType.CLASSIFICATION,
TaskType.ANOMALY_CLASSIFICATION,
):
default_parameters = self.default_model_parameters()
task_global_parameters = dict(default_parameters)
task_global_parameters["task_type"] = global_task_type
task_global_model_template = ModelTemplate(**task_global_parameters)
assert (
task_global_model_template.is_task_global()
), f"Expected True value returned by is_task_global for {global_task_type}"
# Check is_task_global method returns False for the other tasks
non_global_task_parameters = dict(default_parameters)
non_global_tasks_list = []
for task_type in TaskType:
if task_type != TaskType.CLASSIFICATION:
if task_type not in (
TaskType.CLASSIFICATION,
TaskType.ANOMALY_CLASSIFICATION,
):
non_global_tasks_list.append(task_type)
for non_global_task in non_global_tasks_list:
non_global_task_parameters["task_type"] = non_global_task
non_global_task_template = ModelTemplate(**non_global_task_parameters)
assert not non_global_task_template.is_task_global(), (
f"Expected False value returned by is_task_global method for {non_global_task}, only CLASSIFICATION "
f"task type is global"
f"Expected False value returned by is_task_global method for {non_global_task}, "
f"only CLASSIFICATION and ANOMALY_CLASSIFICATION task types are global"
)


Expand Down