Skip to content

Commit

Permalink
change eval function
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Jul 23, 2024
1 parent 2095074 commit 9069c6d
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 23 deletions.
76 changes: 62 additions & 14 deletions src/otx/core/data/transform_libs/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from __future__ import annotations

import ast
import copy
import io
import itertools
import math
import operator
import typing
from inspect import isclass
from pathlib import Path
Expand Down Expand Up @@ -3132,19 +3134,19 @@ def generate(cls, config: SubsetConfig) -> Compose:
def _eval_input_size(cls, cfg_transform: dict[str, Any], input_size: int | tuple[int, int] | None) -> None:
"""Evaluate the input_size and replace the placeholder in the init_args.
Input size should be specified as ^{input_size}. (e.g. ^{input_size} * 0.5)
Built-in eval function is used for evaluation. So, everything `eval` function can evaluate is available.
The function decides to pass tuple type or int type based on the type hint of the argument.
So, Please make sure that the type hint is correct.
Input size should be specified as ^{input_size}. (e.g. ^{input_size} * 0.5)
Only simple multiplication or division evaluation is supported. For example,
^{input_size} * -0.5 => supported
^{input_size} * 2.1 / 3 => supported
^{input_size} + 1 => not supported
The function decides to pass tuple type or int type based on the type hint of the argument.
float point values are rounded to int.
"""
if input_size is None:
return
if isinstance(input_size, int):
input_size = (input_size, input_size)
else:
input_size = tuple(input_size)
_input_size: tuple[int, int] = (input_size, input_size) if isinstance(input_size, int) else tuple(input_size) # type: ignore[assignment]

def check_type(value, expected_type) -> bool:
def check_type(value: Any, expected_type: Any) -> bool: # noqa: ANN401
try:
typeguard.check_type(value, expected_type)
except typeguard.TypeCheckError:
Expand All @@ -3159,16 +3161,62 @@ def check_type(value, expected_type) -> bool:
model_cls = get_obj_from_str(cfg_transform["class_path"])

available_types = typing.get_type_hints(model_cls.__init__).get(key)
if available_types is None or check_type(input_size, available_types): # pass tuple[int, int]
cfg_transform["init_args"][key] = tuple(
eval(val.replace("^{input_size}", "np.array(input_size)")).round().astype(np.int32).tolist()
if available_types is None or check_type(_input_size, available_types): # pass tuple[int, int]
cfg_transform["init_args"][key] = cls._safe_eval(
val.replace("^{input_size}", f"({','.join(str(val) for val in _input_size)})"),
)
elif check_type(input_size[0], available_types): # pass int
cfg_transform["init_args"][key] = round(eval(val.replace("^{input_size}", "input_size[0]")))
elif check_type(_input_size[0], available_types): # pass int
cfg_transform["init_args"][key] = cls._safe_eval(val.replace("^{input_size}", str(_input_size[0])))
else:
msg = f"{key} argument should be able to get int or tuple[int, int], but it can get {available_types}"
raise RuntimeError(msg)

@classmethod
def _safe_eval(cls, str_to_eval: str) -> tuple[int, ...] | int:
    """Evaluate an input-size expression without using the built-in ``eval``.

    The function is implemented for `_eval_input_size`, so the supported grammar is
    intentionally small:

    - Only multiplication and division are supported as binary operations.
    - Unary ``+`` and ``-`` may be applied to any supported sub-expression.
    - Operands must be numeric constants (int or float) or tuples of supported sub-expressions.
    - A tuple is converted to a numpy array before evaluation, so operations apply element-wise.
    - The result is rounded to int.

    Args:
        str_to_eval: Expression to evaluate, e.g. ``"(224,224) * 0.5"``.

    Returns:
        A tuple of ints if the expression evaluated to an array, otherwise a single int.

    Raises:
        SyntaxError: If the string cannot be parsed or contains anything outside the grammar above.
    """
    bin_ops = {
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }
    un_ops = {
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }
    available_ops = tuple(bin_ops) + tuple(un_ops) + (ast.BinOp, ast.UnaryOp)

    def _bad_syntax(node: Any) -> SyntaxError:  # noqa: ANN401
        msg = f"Bad syntax, {type(node)}. Available operations for calcualting input size are {available_ops}"
        return SyntaxError(msg)

    def _eval(node: Any) -> Any:  # noqa: ANN401
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant):
            # Strings, bytes, None, complex, ... would only blow up later (e.g. in `round`),
            # so reject them here with the same error as any other unsupported node.
            if not isinstance(node.value, (int, float)):
                raise _bad_syntax(node)
            return node.value
        if isinstance(node, ast.Tuple):
            return np.array([_eval(val) for val in node.elts])
        if isinstance(node, ast.BinOp) and type(node.op) in bin_ops:
            left = _eval(node.left)
            right = _eval(node.right)
            return bin_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in un_ops:
            # Always recurse: the operand may be a tuple or an unsupported node, neither of
            # which has a `.value` attribute to read directly.
            return un_ops[type(node.op)](_eval(node.operand))
        raise _bad_syntax(node)

    ret = _eval(ast.parse(str_to_eval, mode="eval"))
    if isinstance(ret, np.ndarray):
        return tuple(ret.round().astype(np.int32).tolist())
    return round(ret)

@classmethod
def _dispatch_transform(cls, cfg_transform: DictConfig | dict | tvt_v2.Transform) -> tvt_v2.Transform:
if isinstance(cfg_transform, (DictConfig, dict)):
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import importlib
from collections import defaultdict
from multiprocessing import cpu_count
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch
from datumaro.components.annotation import AnnotationType, LabelCategories
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = Fal
return stats


def get_obj_from_str(obj_path: str) -> Any:
def get_obj_from_str(obj_path: str) -> Any: # noqa: ANN401
"""Get object from import format string."""
module_name, obj_name = obj_path.rsplit(".", 1)
module = importlib.import_module(module_name)
Expand Down
25 changes: 19 additions & 6 deletions tests/unit/core/data/test_transform_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,25 +186,25 @@ def test_transform(
item = dataset[0]
assert isinstance(item, data_entity_cls)

@pytest.fixture
@pytest.fixture()
def fxt_config_w_input_size(self) -> list[dict[str, Any]]:
cfg = f"""
cfg = """
input_size:
- 224
- 224
transforms:
- class_path: otx.core.data.transform_libs.torchvision.ResizetoLongestEdge
init_args:
size: ^{{input_size}} * 2
size: ^{input_size} * 2
- class_path: otx.core.data.transform_libs.torchvision.RandomResize
init_args:
scale: ^{{input_size}} * 0.5
scale: ^{input_size} * 0.5
- class_path: otx.core.data.transform_libs.torchvision.RandomCrop
init_args:
crop_size: ^{{input_size}}
crop_size: ^{input_size}
- class_path: otx.core.data.transform_libs.torchvision.RandomResize
init_args:
scale: ^{{input_size}} * 1.1
scale: ^{input_size} * 1.1
"""
return OmegaConf.create(cfg)

Expand All @@ -216,6 +216,19 @@ def test_eval_input_size(self, fxt_config_w_input_size):
assert transform.transforms[2].crop_size == (224, 224) # RandomCrop gets sequence of integer
assert transform.transforms[3].scale == (round(224 * 1.1), round(224 * 1.1)) # check round

def test_safe_eval(self):
    """Check `_safe_eval` on scalars, tuples, and scaled tuples (including rounding)."""
    evaluate = TorchVisionTransformLib._safe_eval
    expectations = (
        ("2", 2),
        ("(2, 3)", (2, 3)),
        ("2*3", 6),
        ("(2, 3) *3", (6, 9)),
        ("(5, 5) / 2", (2, 2)),
        ("(10, 11) * -0.5", (-5, -6)),
    )
    for expression, expected in expectations:
        assert evaluate(expression) == expected

@pytest.mark.parametrize("input_str", ["1+1", "1+-5", "rm fake", "hoho", "DecordDecode()"])
def test_safe_eval_wrong_value(self, input_str):
    """Expressions outside the multiply/divide grammar must be rejected with SyntaxError."""
    evaluate = TorchVisionTransformLib._safe_eval
    with pytest.raises(SyntaxError):
        assert evaluate(input_str)

@pytest.fixture(params=["RGB", "BGR"])
def fxt_image_color_channel(self, request) -> ImageColorChannel:
return ImageColorChannel(request.param)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/core/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
get_adaptive_num_workers,
get_idx_list_per_classes,
get_mean_std_from_data_processing,
get_obj_from_str,
is_ckpt_for_finetuning,
is_ckpt_from_otx_v1,
get_obj_from_str,
)


Expand Down Expand Up @@ -115,6 +115,7 @@ def test_get_idx_list_per_classes(fxt_dm_dataset):
expected_result["1"] = list(range(100, 108))
assert result == expected_result


def test_get_obj_from_str():
obj_path = "otx.core.utils.utils.get_mean_std_from_data_processing"
obj = get_obj_from_str(obj_path)
Expand Down

0 comments on commit 9069c6d

Please sign in to comment.