diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py index d45b4fb84b27..0b48dfc2b0e6 100644 --- a/python/tvm/tir/schedule/_type_checker.py +++ b/python/tvm/tir/schedule/_type_checker.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Type checking functionality""" +import collections +import collections.abc import functools import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union @@ -26,6 +28,7 @@ def _is_none_type(type_: Any) -> bool: if hasattr(typing, "_GenericAlias"): + # For python versions 3.7 onward, check the __origin__ attribute. class _Subtype: @staticmethod @@ -71,7 +74,15 @@ def union(type_: Any) -> Optional[List[type]]: return list(subtypes) return None + @staticmethod + def callable(type_: Any) -> Optional[List[type]]: + if _Subtype._origin(type_) is collections.abc.Callable: + subtypes = type_.__args__ + return subtypes + return None + elif hasattr(typing, "_Union"): + # For python 3.6 and below, check the __name__ attribute, or CallableMeta. class _Subtype: # type: ignore @staticmethod @@ -114,6 +125,13 @@ def union(type_: Any) -> Optional[List[type]]: return list(subtypes) return None + @staticmethod + def callable(type_: Any) -> Optional[List[type]]: + if isinstance(type_, typing.CallableMeta): # type: ignore # pylint: disable=no-member,protected-access + subtypes = type_.__args__ + return subtypes + return None + def _dispatcher(type_: Any) -> Tuple[str, List[type]]: if _is_none_type(type_): @@ -139,12 +157,27 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]: if subtype is not None: return "union", subtype + subtype = _Subtype.callable(type_) + if subtype is not None: + return "callable", subtype + return "atomic", [type_] +def callable_str(subtypes): + if subtypes: + *arg_types, return_type = subtypes + arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types) + return_type_str = _type2str(return_type) + return f"Callable[[{arg_str}], {return_type_str}]" + else: + return "Callable" + + _TYPE2STR: Dict[Any, Callable] = { "none": lambda: "None", "atomic": lambda t: str(t.__name__), + "callable": callable_str, "list": lambda t: f"List[{_type2str(t)}]", "dict": lambda k, v: f"Dict[{_type2str(k)}, {_type2str(v)}]", "tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]", @@ -188,6 +221,12 @@ def _type_check_none(v: Any, name: str) -> Optional[str]: def _type_check_atomic(v: Any, name: str, type_: Any) -> Optional[str]: return None if isinstance(v, type_) else _type_check_err(v, name, type_) + def _type_check_callable(v: Any, name: str, *_subtypes: Any) -> Optional[str]: + # Current implementation only validates that the argument is + # callable, and doesn't validate the arguments accepted by the + # callable, if any. + return None if callable(v) else _type_check_err(v, name, Callable) + def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]: if not isinstance(v, (list, tuple)): return _type_check_err(v, name, list) @@ -234,6 +273,7 @@ def _type_check_union(v: Any, name: str, *types: Any) -> Optional[str]: return { "none": _type_check_none, "atomic": _type_check_atomic, + "callable": _type_check_callable, "list": _type_check_list, "dict": _type_check_dict, "tuple": _type_check_tuple, diff --git a/tests/python/unittest/test_type_annotation_checker.py b/tests/python/unittest/test_type_annotation_checker.py index e84ae043d356..204c15331339 100644 --- a/tests/python/unittest/test_type_annotation_checker.py +++ b/tests/python/unittest/test_type_annotation_checker.py @@ -17,13 +17,22 @@ """Test type checker based on python's type annotations""" import sys -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union, Callable import pytest +import _pytest from tvm.tir.schedule._type_checker import type_checked +def int_func(x: int) -> int: + return 2 * x + + +def str_func(x: str) -> str: + return 2 * x + + test_cases = [ { "type_annotation": int, @@ -90,30 +99,71 @@ None, ], }, + { + "type_annotation": Callable, + "positive_cases": [str_func, int_func], + "negative_cases": [ + None, + "x", + 42, + ], + }, + { + "type_annotation": Callable[[int], int], + "positive_cases": [int_func], + "negative_cases": [ + None, + "x", + 42, + pytest.param( + str_func, + marks=pytest.mark.xfail( + reason="Signature of Callable arguments not currently checked" + ), + ), + ], + }, ] -positive_cases = [ - (config["type_annotation"], case) for config in test_cases for case in config["positive_cases"] -] - -negative_cases = [ - (config["type_annotation"], case) for config in test_cases for case in config["negative_cases"] -] +def make_parametrization(type_annotation, case): + if isinstance(case, _pytest.mark.structures.ParameterSet): + marks = case.marks + (case,) = case.values + else: + marks = [] -def format_name(type_annotation, case): try: - name = type_annotation.__name__ + annotation_name = type_annotation.__name__ except AttributeError: - name = str(type_annotation).replace("typing.", "") + annotation_name = str(type_annotation).replace("typing.", "") + + if hasattr(case, "__name__"): + case_name = case.__name__ + else: + case_name = str(case) - return f"{name}_{case}" + name = f"{annotation_name}, {case_name}" + + return pytest.param(type_annotation, case, marks=marks, id=name) + + +positive_cases = [ + make_parametrization(config["type_annotation"], case) + for config in test_cases + for case in config["positive_cases"] +] + +negative_cases = [ + make_parametrization(config["type_annotation"], case) + for config in test_cases + for case in config["negative_cases"] +] @pytest.mark.parametrize( ["type_annotation", "case"], positive_cases, - ids=[format_name(t, c) for t, c in positive_cases], ) def test_matches_type(type_annotation, case): @type_checked @@ -126,7 +176,6 @@ def func(_: type_annotation): @pytest.mark.parametrize( ["type_annotation", "case"], negative_cases, - ids=[format_name(t, c) for t, c in negative_cases], ) def test_not_matches(type_annotation, case): @type_checked