diff --git a/ote_sdk/ote_sdk/utils/argument_checks.py b/ote_sdk/ote_sdk/utils/argument_checks.py index 30a2b40d70a..1733cbc6a02 100644 --- a/ote_sdk/ote_sdk/utils/argument_checks.py +++ b/ote_sdk/ote_sdk/utils/argument_checks.py @@ -7,30 +7,84 @@ # import inspect +import itertools import typing from abc import ABC, abstractmethod from collections.abc import Sequence from functools import wraps -from os.path import exists +from os.path import exists, splitext import yaml from numpy import floating from omegaconf import DictConfig +IMAGE_FILE_EXTENSIONS = [ + ".bmp", + ".dib", + ".jpeg", + ".jpg", + ".jpe", + ".jp2", + ".png", + ".webp", + ".pbm", + ".pgm", + ".ppm", + ".pxm", + ".pnm", + ".sr", + ".ras", + ".tiff", + ".tif", + ".exr", + ".hdr", + ".pic", +] + + +def get_bases(parameter) -> set: + """Function to get set of all base classes of parameter""" + + def __get_bases(parameter_type): + return [parameter_type.__name__] + list( + itertools.chain.from_iterable( + __get_bases(t1) for t1 in parameter_type.__bases__ + ) + ) + + return set(__get_bases(type(parameter))) + + +def get_parameter_repr(parameter) -> str: + """Function to get parameter representation""" + try: + parameter_str = repr(parameter) + # pylint: disable=broad-except + except Exception: + parameter_str = "" + return parameter_str + def raise_value_error_if_parameter_has_unexpected_type( parameter, parameter_name, expected_type ): """Function raises ValueError exception if parameter has unexpected type""" + if isinstance(expected_type, typing.ForwardRef): + expected_type = expected_type.__forward_arg__ + if isinstance(expected_type, str): + parameter_types = get_bases(parameter) + if not any(t == expected_type for t in parameter_types): + parameter_str = get_parameter_repr(parameter) + raise ValueError( + f"Unexpected type of '{parameter_name}' parameter, expected: {expected_type}, " + f"actual value: {parameter_str}" + ) + return if expected_type == float: expected_type = (int, float, floating) if not isinstance(parameter, expected_type): parameter_type = type(parameter) - try: - parameter_str = repr(parameter) - # pylint: disable=broad-except - except Exception: - parameter_str = "" + parameter_str = get_parameter_repr(parameter) raise ValueError( f"Unexpected type of '{parameter_name}' parameter, expected: {expected_type}, actual: {parameter_type}, " f"actual value: {parameter_str}" @@ -67,21 +121,10 @@ def check_dictionary_keys_values_type( ) -def check_parameter_type(parameter, parameter_name, expected_type): - """Function extracts nested expected types and raises ValueError exception if parameter has unexpected type""" - # pylint: disable=W0212 - if expected_type in [typing.Any, inspect._empty]: # type: ignore - return - if not isinstance(expected_type, typing._GenericAlias): # type: ignore - raise_value_error_if_parameter_has_unexpected_type( - parameter=parameter, - parameter_name=parameter_name, - expected_type=expected_type, - ) - return - expected_type_dict = expected_type.__dict__ - origin_class = expected_type_dict.get("__origin__") - nested_elements_class = expected_type_dict.get("__args__") +def check_nested_classes_parameters( + parameter, parameter_name, origin_class, nested_elements_class +): + """Function to check type of parameters with nested elements""" if origin_class == dict: if len(nested_elements_class) != 2: raise TypeError( @@ -100,18 +143,53 @@ def check_parameter_type(parameter, parameter_name, expected_type): parameter_name=parameter_name, expected_type=origin_class, ) - if len(nested_elements_class) != 1: - raise TypeError( - "length of nested expected types for Sequence should be equal to 1" - ) + if origin_class == tuple: + tuple_length = len(nested_elements_class) + if tuple_length > 2: + raise NotImplementedError( + "length of nested expected types for Tuple should not exceed 2" + ) + if tuple_length == 2: + if nested_elements_class[1] != Ellipsis: + raise NotImplementedError("expected homogeneous tuple annotation") + nested_elements_class = nested_elements_class[0] + else: + if len(nested_elements_class) != 1: + raise TypeError( + "length of nested expected types for Sequence should be equal to 1" + ) check_nested_elements_type( iterable=parameter, parameter_name=parameter_name, expected_type=nested_elements_class, ) + + +def check_parameter_type(parameter, parameter_name, expected_type): + """Function extracts nested expected types and raises ValueError exception if parameter has unexpected type""" + # pylint: disable=W0212 + if expected_type in [typing.Any, inspect._empty]: # type: ignore + return + if not isinstance(expected_type, typing._GenericAlias): # type: ignore + raise_value_error_if_parameter_has_unexpected_type( + parameter=parameter, + parameter_name=parameter_name, + expected_type=expected_type, + ) + return + # Checking parameters with nested elements + expected_type_dict = expected_type.__dict__ + origin_class = expected_type_dict.get("__origin__") + nested_elements_class = expected_type_dict.get("__args__") + check_nested_classes_parameters( + parameter=parameter, + parameter_name=parameter_name, + origin_class=origin_class, + nested_elements_class=nested_elements_class, + ) + # Union type with nested elements check if origin_class == typing.Union: expected_args = expected_type_dict.get("__args__") - # Union type with nested elements check checks_counter = 0 errors_counter = 0 for expected_arg in expected_args: @@ -128,10 +206,13 @@ def check_parameter_type(parameter, parameter_name, expected_type): ) -def check_input_parameters_type(checks_types: dict = None): - """Decorator to check input parameters type""" - if checks_types is None: - checks_types = {} +def check_input_parameters_type(custom_checks: typing.Optional[dict] = None): + """ + Decorator to check input parameters type + :param custom_checks: dictionary where key - name of parameter and value - custom check class + """ + if custom_checks is None: + custom_checks = {} def _check_input_parameters_type(function): @wraps(function) @@ -150,21 +231,23 @@ def validate(*args, **kwargs): ) input_parameters_values_map[key] = value # Checking input parameters type - for parameter in expected_types_map: - input_parameter_actual = input_parameters_values_map.get(parameter) - if input_parameter_actual is None: - default_value = expected_types_map.get(parameter).default + for parameter_name in expected_types_map: + parameter = input_parameters_values_map.get(parameter_name) + if parameter is None: + default_value = expected_types_map.get(parameter_name).default # pylint: disable=protected-access if default_value != inspect._empty: # type: ignore - input_parameter_actual = default_value - custom_check = checks_types.get(parameter) - if custom_check: - custom_check(input_parameter_actual, parameter).check() + parameter = default_value + if parameter_name in custom_checks: + custom_check = custom_checks[parameter_name] + if custom_check is None: + continue + custom_check(parameter, parameter_name).check() else: check_parameter_type( - parameter=input_parameter_actual, - parameter_name=parameter, - expected_type=expected_types_map.get(parameter).annotation, + parameter=parameter, + parameter_name=parameter_name, + expected_type=expected_types_map.get(parameter_name).annotation, ) return function(**input_parameters_values_map) @@ -177,7 +260,7 @@ def check_file_extension( file_path: str, file_path_name: str, expected_extensions: list ): """Function raises ValueError exception if file has unexpected extension""" - file_extension = file_path.split(".")[-1].lower() + file_extension = splitext(file_path)[1].lower() if file_extension not in expected_extensions: raise ValueError( f"Unexpected extension of {file_path_name} file. expected: {expected_extensions} actual: {file_extension}" @@ -314,7 +397,7 @@ def check(self): check_file_extension( file_path=self.parameter, file_path_name=self.parameter_name, - expected_extensions=["yaml"], + expected_extensions=[".yaml"], ) check_that_all_characters_printable( parameter=self.parameter, parameter_name=self.parameter_name @@ -363,46 +446,12 @@ def __init__(self, parameter, parameter_name): self.parameter_name = parameter_name def check(self): - """Method raises ValueError exception if parameter is not equal to DataSet""" + """Method raises ValueError exception if parameter is not equal to Dataset""" check_is_parameter_like_dataset( parameter=self.parameter, parameter_name=self.parameter_name ) -class OptionalDatasetParamTypeCheck(DatasetParamTypeCheck): - """Class to check DatasetEntity-type parameters""" - - def check(self): - """Method raises ValueError exception if parameter is not equal to DataSet""" - if self.parameter is not None: - check_is_parameter_like_dataset( - parameter=self.parameter, parameter_name=self.parameter_name - ) - - -class OptionalModelParamTypeCheck(BaseInputArgumentChecker): - """Class to check ModelEntity-type parameters""" - - def __init__(self, parameter, parameter_name): - self.parameter = parameter - self.parameter_name = parameter_name - - def check(self): - """Method raises ValueError exception if parameter is not equal to DataSet""" - if self.parameter is not None: - for expected_attribute in ( - "__train_dataset__", - "__previous_trained_revision__", - "__model_format__", - ): - if not hasattr(self.parameter, expected_attribute): - parameter_type = type(self.parameter) - raise ValueError( - f"parameter '{self.parameter_name}' is not like ModelEntity, actual type: {parameter_type} " - f"which does not have expected '{expected_attribute}' Model attribute" - ) - - class OptionalImageFilePathCheck(OptionalFilePathCheck): """Class to check optional image file path parameters""" @@ -410,7 +459,7 @@ def __init__(self, parameter, parameter_name): super().__init__( parameter=parameter, parameter_name=parameter_name, - expected_file_extension=["jpg", "png"], + expected_file_extension=IMAGE_FILE_EXTENSIONS, ) @@ -421,7 +470,7 @@ def __init__(self, parameter, parameter_name): super().__init__( parameter=parameter, parameter_name=parameter_name, - expected_file_extension=["yaml"], + expected_file_extension=[".yaml"], )