From 9069c6db1ef81ad1b41740e7169c16fc50e9a02a Mon Sep 17 00:00:00 2001 From: "Shin, Eunwoo" Date: Tue, 23 Jul 2024 18:09:08 +0900 Subject: [PATCH] change eval function --- .../core/data/transform_libs/torchvision.py | 76 +++++++++++++++---- src/otx/core/utils/utils.py | 4 +- tests/unit/core/data/test_transform_libs.py | 25 ++++-- tests/unit/core/utils/test_utils.py | 3 +- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/src/otx/core/data/transform_libs/torchvision.py b/src/otx/core/data/transform_libs/torchvision.py index 85be4d68a5e..4ad75fcae06 100644 --- a/src/otx/core/data/transform_libs/torchvision.py +++ b/src/otx/core/data/transform_libs/torchvision.py @@ -5,10 +5,12 @@ from __future__ import annotations +import ast import copy import io import itertools import math +import operator import typing from inspect import isclass from pathlib import Path @@ -3132,19 +3134,19 @@ def generate(cls, config: SubsetConfig) -> Compose: def _eval_input_size(cls, cfg_transform: dict[str, Any], input_size: int | tuple[int, int] | None) -> None: """Evaluate the input_size and replace the placeholder in the init_args. - Input size should be specified as ^{input_size}. (e.g. ^{input_size} * 0.5) - Built-in eval function is used for evaluation. So, everything `eval` function can evaluate is available. - The function decides to pass tuple type or int type based on the type hint of the argument. - So, Please make sure that the type hint is correct. + Input size should be specified as ^{input_size}. (e.g. ^{input_size} * 0.5) + Only simple multiplication or division evaluation is supported. For example, + ^{input_size} * -0.5 => supported + ^{input_size} * 2.1 / 3 => supported + ^{input_size} + 1 => not supported + The function decides to pass tuple type or int type based on the type hint of the argument. + float point values are rounded to int. """ if input_size is None: return - if isinstance(input_size, int): - input_size = (input_size, input_size) - else: - input_size = tuple(input_size) + _input_size: tuple[int, int] = (input_size, input_size) if isinstance(input_size, int) else tuple(input_size) # type: ignore[assignment] - def check_type(value, expected_type) -> bool: + def check_type(value: Any, expected_type: Any) -> bool: # noqa: ANN401 try: typeguard.check_type(value, expected_type) except typeguard.TypeCheckError: @@ -3159,16 +3161,62 @@ def check_type(value, expected_type) -> bool: model_cls = get_obj_from_str(cfg_transform["class_path"]) available_types = typing.get_type_hints(model_cls.__init__).get(key) - if available_types is None or check_type(input_size, available_types): # pass tuple[int, int] - cfg_transform["init_args"][key] = tuple( - eval(val.replace("^{input_size}", "np.array(input_size)")).round().astype(np.int32).tolist() + if available_types is None or check_type(_input_size, available_types): # pass tuple[int, int] + cfg_transform["init_args"][key] = cls._safe_eval( + val.replace("^{input_size}", f"({','.join(str(val) for val in _input_size)})"), ) - elif check_type(input_size[0], available_types): # pass int - cfg_transform["init_args"][key] = round(eval(val.replace("^{input_size}", "input_size[0]"))) + elif check_type(_input_size[0], available_types): # pass int + cfg_transform["init_args"][key] = cls._safe_eval(val.replace("^{input_size}", str(_input_size[0]))) else: msg = f"{key} argument should be able to get int or tuple[int, int], but it can get {available_types}" raise RuntimeError(msg) + @classmethod + def _safe_eval(cls, str_to_eval: str) -> tuple[int, ...] | int: + """Safe eval function for _eval_input_size. + + The function is implemented for `_eval_input_size`, so implementation is aligned to it as below + - Only multiplication or division evaluation are supported. + - Only constant and tuple can be operand. + - tuple is changed to numpy array before evaluation. + - result value is rounded to int. + """ + bin_ops = { + ast.Mult: operator.mul, + ast.Div: operator.truediv, + } + + un_ops = { + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + available_ops = tuple(bin_ops) + tuple(un_ops) + (ast.BinOp, ast.UnaryOp) + + tree = ast.parse(str_to_eval, mode="eval") + + def _eval(node: Any) -> Any: # noqa: ANN401 + if isinstance(node, ast.Expression): + return _eval(node.body) + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.Tuple): + return np.array([_eval(val) for val in node.elts]) + if isinstance(node, ast.BinOp) and type(node.op) in bin_ops: + left = _eval(node.left) + right = _eval(node.right) + return bin_ops[type(node.op)](left, right) + if isinstance(node, ast.UnaryOp) and type(node.op) in un_ops: + operand = _eval(node.operand) if isinstance(node.operand, available_ops) else node.operand.value + return un_ops[type(node.op)](operand) # type: ignore[operator] + msg = f"Bad syntax, {type(node)}. Available operations for calcualting input size are {available_ops}" + raise SyntaxError(msg) + + ret = _eval(tree) + if isinstance(ret, np.ndarray): + return tuple(ret.round().astype(np.int32).tolist()) + return round(ret) + @classmethod def _dispatch_transform(cls, cfg_transform: DictConfig | dict | tvt_v2.Transform) -> tvt_v2.Transform: if isinstance(cfg_transform, (DictConfig, dict)): diff --git a/src/otx/core/utils/utils.py b/src/otx/core/utils/utils.py index b1c6c2b0eb9..b1a70e6a852 100644 --- a/src/otx/core/utils/utils.py +++ b/src/otx/core/utils/utils.py @@ -8,7 +8,7 @@ import importlib from collections import defaultdict from multiprocessing import cpu_count -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch from datumaro.components.annotation import AnnotationType, LabelCategories @@ -87,7 +87,7 @@ def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = Fal return stats -def get_obj_from_str(obj_path: str) -> Any: +def get_obj_from_str(obj_path: str) -> Any: # noqa: ANN401 """Get object from import format string.""" module_name, obj_name = obj_path.rsplit(".", 1) module = importlib.import_module(module_name) diff --git a/tests/unit/core/data/test_transform_libs.py b/tests/unit/core/data/test_transform_libs.py index e395deb6d8f..49c8a43b9b1 100644 --- a/tests/unit/core/data/test_transform_libs.py +++ b/tests/unit/core/data/test_transform_libs.py @@ -186,25 +186,25 @@ def test_transform( item = dataset[0] assert isinstance(item, data_entity_cls) - @pytest.fixture + @pytest.fixture() def fxt_config_w_input_size(self) -> list[dict[str, Any]]: - cfg = f""" + cfg = """ input_size: - 224 - 224 transforms: - class_path: otx.core.data.transform_libs.torchvision.ResizetoLongestEdge init_args: - size: ^{{input_size}} * 2 + size: ^{input_size} * 2 - class_path: otx.core.data.transform_libs.torchvision.RandomResize init_args: - scale: ^{{input_size}} * 0.5 + scale: ^{input_size} * 0.5 - class_path: otx.core.data.transform_libs.torchvision.RandomCrop init_args: - crop_size: ^{{input_size}} + crop_size: ^{input_size} - class_path: otx.core.data.transform_libs.torchvision.RandomResize init_args: - scale: ^{{input_size}} * 1.1 + scale: ^{input_size} * 1.1 """ return OmegaConf.create(cfg) @@ -216,6 +216,19 @@ def test_eval_input_size(self, fxt_config_w_input_size): assert transform.transforms[2].crop_size == (224, 224) # RandomCrop gets sequence of integer assert transform.transforms[3].scale == (round(224 * 1.1), round(224 * 1.1)) # check round + def test_safe_eval(self): + assert TorchVisionTransformLib._safe_eval("2") == 2 + assert TorchVisionTransformLib._safe_eval("(2, 3)") == (2, 3) + assert TorchVisionTransformLib._safe_eval("2*3") == 6 + assert TorchVisionTransformLib._safe_eval("(2, 3) *3") == (6, 9) + assert TorchVisionTransformLib._safe_eval("(5, 5) / 2") == (2, 2) + assert TorchVisionTransformLib._safe_eval("(10, 11) * -0.5") == (-5, -6) + + @pytest.mark.parametrize("input_str", ["1+1", "1+-5", "rm fake", "hoho", "DecordDecode()"]) + def test_safe_eval_wrong_value(self, input_str): + with pytest.raises(SyntaxError): + assert TorchVisionTransformLib._safe_eval(input_str) + @pytest.fixture(params=["RGB", "BGR"]) def fxt_image_color_channel(self, request) -> ImageColorChannel: return ImageColorChannel(request.param) diff --git a/tests/unit/core/utils/test_utils.py b/tests/unit/core/utils/test_utils.py index 80ac9d0bce5..931c361d1b5 100644 --- a/tests/unit/core/utils/test_utils.py +++ b/tests/unit/core/utils/test_utils.py @@ -9,9 +9,9 @@ get_adaptive_num_workers, get_idx_list_per_classes, get_mean_std_from_data_processing, + get_obj_from_str, is_ckpt_for_finetuning, is_ckpt_from_otx_v1, - get_obj_from_str, ) @@ -115,6 +115,7 @@ def test_get_idx_list_per_classes(fxt_dm_dataset): expected_result["1"] = list(range(100, 108)) assert result == expected_result + def test_get_obj_from_str(): obj_path = "otx.core.utils.utils.get_mean_std_from_data_processing" obj = get_obj_from_str(obj_path)