From 0c50de6144d99eedb402a8e85eb8187098f8c26f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 27 Oct 2024 00:53:03 +0300 Subject: [PATCH 01/11] Declared Python 3.13 support --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index cf93b98..4228242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] requires-python = ">= 3.8" dependencies = [ From d812f2eba9f5e898544eb4b3e597f8c38b0952e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 27 Oct 2024 00:57:49 +0300 Subject: [PATCH 02/11] Migrated to native tox TOML configuration --- pyproject.toml | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4228242..226190d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,19 +98,15 @@ strict = true pretty = true [tool.tox] -legacy_tox_ini = """ -[tox] -envlist = pypy3, py38, py39, py310, py311, py312, py313 +env_list = ["py38", "py39", "py310", "py311", "py312", "py313"] skip_missing_interpreters = true -minversion = 4.0 -[testenv] -extras = test -commands = coverage run -m pytest {posargs} -package = editable +[tool.tox.env_run_base] +commands = [["coverage", "run", "-m", "pytest", { replace = "posargs", extend = true }]] +package = "editable" +extras = ["test"] -[testenv:docs] -extras = doc -package = editable -commands = sphinx-build -W -n docs build/sphinx -""" +[tool.tox.env.docs] +depends = [] +extras = ["doc"] +commands = [["sphinx-build", "-W", "-n", "docs", "build/sphinx"]] From afad2c7b6be830900776922bb39f9346c2e77f6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 27 Oct 2024 00:58:35 +0300 Subject: [PATCH 03/11] Sorted the Ruff rules alphabetically --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 226190d..5df4ab0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,11 +81,11 @@ src = ["src"] [tool.ruff.lint] extend-select = [ - "W", # pycodestyle warnings + "B0", # flake8-bugbear "I", # isort "PGH", # pygrep-hooks "UP", # pyupgrade - "B0", # flake8-bugbear + "W", # pycodestyle warnings ] ignore = [ "S307", From b72794dffe403254881ac0c327155357c43ccebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 27 Oct 2024 12:46:44 +0200 Subject: [PATCH 04/11] Added proper Protocol method signature checking (#496) It's not good enough to pretend we can use `check_callable()` to check method signature compatibility. Fixes #465. --- docs/features.rst | 9 +- docs/versionhistory.rst | 4 + src/typeguard/_checkers.py | 250 ++++++++++++++++++++++----------- tests/__init__.py | 15 -- tests/test_checkers.py | 275 ++++++++++++++++++++++++++----------- 5 files changed, 375 insertions(+), 178 deletions(-) diff --git a/docs/features.rst b/docs/features.rst index 3cd1e34..3141456 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -65,13 +65,8 @@ As of version 4.3.0, Typeguard can check instances and classes against Protocols regardless of whether they were annotated with :func:`@runtime_checkable `. -There are several limitations on the checks performed, however: - -* For non-callable members, only presence is checked for; no type compatibility checks - are performed -* For methods, only the number of positional arguments are checked against, so any added - keyword-only arguments without defaults don't currently trip the checker -* Likewise, argument types are not checked for compatibility +The only current limitation is that argument annotations are not checked for +compatibility, however this should be covered by static type checkers pretty well. Special considerations for ``if TYPE_CHECKING:`` ------------------------------------------------ diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 64e33d8..de7c154 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -6,8 +6,12 @@ This library adheres to **UNRELEASED** +- Added proper checking for method signatures in protocol checks + (`#465 `_) - Fixed basic support for intersection protocols (`#490 `_; PR by @antonagestam) +- Fixed protocol checks running against the class of an instance and not the instance + itself (this produced wrong results for non-method member checks) **4.3.0** (2024-05-27) diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 52ec2b8..44ca34b 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -9,6 +9,7 @@ from enum import Enum from inspect import Parameter, isclass, isfunction from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase +from itertools import zip_longest from textwrap import indent from typing import ( IO, @@ -32,7 +33,6 @@ Union, ) from unittest.mock import Mock -from weakref import WeakKeyDictionary import typing_extensions @@ -86,10 +86,6 @@ if sys.version_info >= (3, 9): generic_alias_types += (types.GenericAlias,) -protocol_check_cache: WeakKeyDictionary[ - type[Any], dict[type[Any], TypeCheckError | None] -] = WeakKeyDictionary() - # Sentinel _missing = object() @@ -638,96 +634,196 @@ def check_io( raise TypeCheckError("is not an I/O object") -def check_protocol( - value: Any, - origin_type: Any, - args: tuple[Any, ...], - memo: TypeCheckMemo, +def check_signature_compatible( + subject_callable: Callable[..., Any], protocol: type, attrname: str ) -> None: - subject: type[Any] = value if isclass(value) else type(value) + subject_sig = inspect.signature(subject_callable) + protocol_sig = inspect.signature(getattr(protocol, attrname)) + protocol_type: typing.Literal["instance", "class", "static"] = "instance" + subject_type: typing.Literal["instance", "class", "static"] = "instance" + + # Check if the protocol-side method is a class method or static method + if attrname in protocol.__dict__: + descriptor = protocol.__dict__[attrname] + if isinstance(descriptor, staticmethod): + protocol_type = "static" + elif isinstance(descriptor, classmethod): + protocol_type = "class" + + # Check if the subject-side method is a class method or static method + if inspect.ismethod(subject_callable) and inspect.isclass( + subject_callable.__self__ + ): + subject_type = "class" + elif not hasattr(subject_callable, "__self__"): + subject_type = "static" - if subject in protocol_check_cache: - result_map = protocol_check_cache[subject] - if origin_type in result_map: - if exc := result_map[origin_type]: - raise exc - else: - return + if protocol_type == "instance" and subject_type != "instance": + raise TypeCheckError( + f"should be an instance method but it's a {subject_type} method" + ) + elif protocol_type != "instance" and subject_type == "instance": + raise TypeCheckError( + f"should be a {protocol_type} method but it's an instance method" + ) - expected_methods: dict[str, tuple[Any, Any]] = {} - expected_noncallable_members: dict[str, Any] = {} - origin_annotations = typing.get_type_hints(origin_type) + expected_varargs = any( + param + for param in protocol_sig.parameters.values() + if param.kind is Parameter.VAR_POSITIONAL + ) + has_varargs = any( + param + for param in subject_sig.parameters.values() + if param.kind is Parameter.VAR_POSITIONAL + ) + if expected_varargs and not has_varargs: + raise TypeCheckError("should accept variable positional arguments but doesn't") + + protocol_has_varkwargs = any( + param + for param in protocol_sig.parameters.values() + if param.kind is Parameter.VAR_KEYWORD + ) + subject_has_varkwargs = any( + param + for param in subject_sig.parameters.values() + if param.kind is Parameter.VAR_KEYWORD + ) + if protocol_has_varkwargs and not subject_has_varkwargs: + raise TypeCheckError("should accept variable keyword arguments but doesn't") + + # Check that the callable has at least the expect amount of positional-only + # arguments (and no extra positional-only arguments without default values) + if not has_varargs: + protocol_args = [ + param + for param in protocol_sig.parameters.values() + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ] + subject_args = [ + param + for param in subject_sig.parameters.values() + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ] + + # Remove the "self" parameter from the protocol arguments to match + if protocol_type == "instance": + protocol_args.pop(0) + + for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args): + if protocol_arg is None: + if subject_arg.default is Parameter.empty: + raise TypeCheckError("has too many mandatory positional arguments") + + break + + if subject_arg is None: + raise TypeCheckError("has too few positional arguments") + + if ( + protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD + and subject_arg.kind is Parameter.POSITIONAL_ONLY + ): + raise TypeCheckError( + f"has an argument ({subject_arg.name}) that should not be " + f"positional-only" + ) + + if ( + protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD + and protocol_arg.name != subject_arg.name + ): + raise TypeCheckError( + f"has a positional argument ({subject_arg.name}) that should be " + f"named {protocol_arg.name!r} at this position" + ) - for attrname in typing_extensions.get_protocol_members(origin_type): - member = getattr(origin_type, attrname, None) - - if callable(member): - signature = inspect.signature(member) - argtypes = [ - (p.annotation if p.annotation is not Parameter.empty else Any) - for p in signature.parameters.values() - if p.kind is not Parameter.KEYWORD_ONLY - ] or Ellipsis - return_annotation = ( - signature.return_annotation - if signature.return_annotation is not Parameter.empty - else Any + protocol_kwonlyargs = { + param.name: param + for param in protocol_sig.parameters.values() + if param.kind is Parameter.KEYWORD_ONLY + } + subject_kwonlyargs = { + param.name: param + for param in subject_sig.parameters.values() + if param.kind is Parameter.KEYWORD_ONLY + } + if not subject_has_varkwargs: + # Check that the signature has at least the required keyword-only arguments, and + # no extra mandatory keyword-only arguments + if missing_kwonlyargs := [ + param.name + for param in protocol_kwonlyargs.values() + if param.name not in subject_kwonlyargs + ]: + raise TypeCheckError( + "is missing keyword-only arguments: " + ", ".join(missing_kwonlyargs) ) - expected_methods[attrname] = argtypes, return_annotation - else: - try: - expected_noncallable_members[attrname] = origin_annotations[attrname] - except KeyError: - expected_noncallable_members[attrname] = member - subject_annotations = typing.get_type_hints(subject) + if not protocol_has_varkwargs: + if extra_kwonlyargs := [ + param.name + for param in subject_kwonlyargs.values() + if param.default is Parameter.empty + and param.name not in protocol_kwonlyargs + ]: + raise TypeCheckError( + "has mandatory keyword-only arguments not present in the protocol: " + + ", ".join(extra_kwonlyargs) + ) - # Check that all required methods are present and their signatures are compatible - result_map = protocol_check_cache.setdefault(subject, {}) - try: - for attrname, callable_args in expected_methods.items(): + +def check_protocol( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + origin_annotations = typing.get_type_hints(origin_type) + for attrname in sorted(typing_extensions.get_protocol_members(origin_type)): + if (annotation := origin_annotations.get(attrname)) is not None: try: - method = getattr(subject, attrname) + subject_member = getattr(value, attrname) except AttributeError: - if attrname in subject_annotations: - raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because its {attrname!r} attribute is not a method" - ) from None - else: - raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because it has no method named {attrname!r}" - ) from None - - if not callable(method): raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because its {attrname!r} attribute is not a callable" - ) + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because it has no attribute named {attrname!r}" + ) from None - # TODO: raise exception on added keyword-only arguments without defaults try: - check_callable(method, Callable, callable_args, memo) + check_type_internal(subject_member, annotation, memo) except TypeCheckError as exc: raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because its {attrname!r} method {exc}" + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because its {attrname!r} attribute {exc}" + ) from None + elif callable(getattr(origin_type, attrname)): + try: + subject_member = getattr(value, attrname) + except AttributeError: + raise TypeCheckError( + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because it has no method named {attrname!r}" ) from None - # Check that all required non-callable members are present - for attrname in expected_noncallable_members: - # TODO: implement assignability checks for non-callable members - if attrname not in subject_annotations and not hasattr(subject, attrname): + if not callable(subject_member): raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because it has no attribute named {attrname!r}" + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because its {attrname!r} attribute is not a callable" ) - except TypeCheckError as exc: - result_map[origin_type] = exc - raise - else: - result_map[origin_type] = None + + # TODO: implement assignability checks for parameter and return value + # annotations + try: + check_signature_compatible(subject_member, origin_type, attrname) + except TypeCheckError as exc: + raise TypeCheckError( + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because its {attrname!r} method {exc}" + ) from None def check_byteslike( diff --git a/tests/__init__.py b/tests/__init__.py index f28f2c2..b48bd69 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,10 +6,8 @@ List, NamedTuple, NewType, - Protocol, TypeVar, Union, - runtime_checkable, ) T_Foo = TypeVar("T_Foo") @@ -44,16 +42,3 @@ class Parent: class Child(Parent): def method(self, a: int) -> None: pass - - -class StaticProtocol(Protocol): - member: int - - def meth(self, x: str) -> None: ... - - -@runtime_checkable -class RuntimeProtocol(Protocol): - member: int - - def meth(self, x: str) -> None: ... diff --git a/tests/test_checkers.py b/tests/test_checkers.py index d9237a9..f780964 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -54,8 +54,6 @@ Employee, JSONType, Parent, - RuntimeProtocol, - StaticProtocol, TChild, TIntStr, TParent, @@ -1078,119 +1076,238 @@ def test_raises_for_non_member(self, subject: object, predicate_type: type) -> N check_type(subject, predicate_type) -@pytest.mark.parametrize( - "instantiate, annotation", - [ - pytest.param(True, RuntimeProtocol, id="instance_runtime"), - pytest.param(False, Type[RuntimeProtocol], id="class_runtime"), - pytest.param(True, StaticProtocol, id="instance_static"), - pytest.param(False, Type[StaticProtocol], id="class_static"), - ], -) class TestProtocol: - def test_member_defaultval(self, instantiate, annotation): + def test_success(self, typing_provider: Any) -> None: + class MyProtocol(Protocol): + member: int + + def noargs(self) -> None: + pass + + def posonlyargs(self, a: int, b: str, /) -> None: + pass + + def posargs(self, a: int, b: str, c: float = 2.0) -> None: + pass + + def varargs(self, *args: Any) -> None: + pass + + def varkwargs(self, **kwargs: Any) -> None: + pass + + def varbothargs(self, *args: Any, **kwargs: Any) -> None: + pass + + @staticmethod + def my_static_method(x: int, y: str) -> None: + pass + + @classmethod + def my_class_method(cls, x: int, y: str) -> None: + pass + class Foo: member = 1 - def meth(self, x: str) -> None: + def noargs(self, x: int = 1) -> None: pass - subject = Foo() if instantiate else Foo - for _ in range(2): # Makes sure that the cache is also exercised - check_type(subject, annotation) + def posonlyargs(self, a: int, b: str, c: float = 2.0, /) -> None: + pass - def test_member_annotation(self, instantiate, annotation): - class Foo: + def posargs(self, *args: Any) -> None: + pass + + def varargs(self, *args: Any, kwarg: str = "foo") -> None: + pass + + def varkwargs(self, **kwargs: Any) -> None: + pass + + def varbothargs(self, *args: Any, **kwargs: Any) -> None: + pass + + # These were intentionally reversed, as this is OK for mypy + @classmethod + def my_static_method(cls, x: int, y: str) -> None: + pass + + @staticmethod + def my_class_method(x: int, y: str) -> None: + pass + + check_type(Foo(), MyProtocol) + + @pytest.mark.parametrize("has_member", [True, False]) + def test_member_checks(self, has_member: bool) -> None: + class MyProtocol(Protocol): member: int + class Foo: + def __init__(self, member: int): + if member: + self.member = member + + if has_member: + check_type(Foo(1), MyProtocol) + else: + pytest.raises(TypeCheckError, check_type, Foo(0), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because it has no attribute named " + f"'member'" + ) + + def test_missing_method(self) -> None: + class MyProtocol(Protocol): + def meth(self) -> None: + pass + + class Foo: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because it has no method named " + f"'meth'" + ) + + def test_too_many_posargs(self) -> None: + class MyProtocol(Protocol): + def meth(self) -> None: + pass + + class Foo: def meth(self, x: str) -> None: pass - subject = Foo() if instantiate else Foo - for _ in range(2): - check_type(subject, annotation) + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method has too " + f"many mandatory positional arguments" + ) + + def test_wrong_posarg_name(self) -> None: + class MyProtocol(Protocol): + def meth(self, x: str) -> None: + pass - def test_attribute_missing(self, instantiate, annotation): class Foo: - val = 1 + def meth(self, y: str) -> None: + pass + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + rf"^{qualified_name(Foo)} is not compatible with the " + rf"{MyProtocol.__qualname__} protocol because its 'meth' method has a " + rf"positional argument \(y\) that should be named 'x' at this position" + ) + + def test_too_few_posargs(self) -> None: + class MyProtocol(Protocol): def meth(self, x: str) -> None: pass - clsname = f"{__name__}.TestProtocol.test_attribute_missing..Foo" - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - f"{clsname} is not compatible with the (Runtime|Static)Protocol " - f"protocol because it has no attribute named 'member'" - ) + class Foo: + def meth(self) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method has too " + f"few positional arguments" + ) + + def test_no_varargs(self) -> None: + class MyProtocol(Protocol): + def meth(self, *args: Any) -> None: + pass - def test_method_missing(self, instantiate, annotation): class Foo: - member: int + def meth(self) -> None: + pass - pattern = ( - f"{__name__}.TestProtocol.test_method_missing..Foo is not " - f"compatible with the (Runtime|Static)Protocol protocol because it has no " - f"method named 'meth'" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"accept variable positional arguments but doesn't" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) - def test_attribute_is_not_method_1(self, instantiate, annotation): + def test_no_kwargs(self) -> None: + class MyProtocol(Protocol): + def meth(self, **kwargs: Any) -> None: + pass + class Foo: - member: int - meth: str + def meth(self) -> None: + pass - pattern = ( - f"{__name__}.TestProtocol.test_attribute_is_not_method_1..Foo is " - f"not compatible with the (Runtime|Static)Protocol protocol because its " - f"'meth' attribute is not a method" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"accept variable keyword arguments but doesn't" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) - def test_attribute_is_not_method_2(self, instantiate, annotation): + def test_missing_kwarg(self) -> None: + class MyProtocol(Protocol): + def meth(self, *, x: str) -> None: + pass + class Foo: - member: int - meth = "foo" + def meth(self) -> None: + pass - pattern = ( - f"{__name__}.TestProtocol.test_attribute_is_not_method_2..Foo is " - f"not compatible with the (Runtime|Static)Protocol protocol because its " - f"'meth' attribute is not a callable" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method is " + f"missing keyword-only arguments: x" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) - def test_method_signature_mismatch(self, instantiate, annotation): + def test_extra_kwarg(self) -> None: + class MyProtocol(Protocol): + def meth(self) -> None: + pass + class Foo: - member: int + def meth(self, *, x: str) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method has " + f"mandatory keyword-only arguments not present in the protocol: x" + ) + + def test_instance_staticmethod_mismatch(self) -> None: + class MyProtocol(Protocol): + @staticmethod + def meth() -> None: + pass - def meth(self, x: str, y: int) -> None: + class Foo: + def meth(self) -> None: pass - pattern = ( - rf"(class )?{__name__}.TestProtocol.test_method_signature_mismatch." - rf".Foo is not compatible with the (Runtime|Static)Protocol " - rf"protocol because its 'meth' method has too many mandatory positional " - rf"arguments in its declaration; expected 2 but 3 mandatory positional " - rf"argument\(s\) declared" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"be a static method but it's an instance method" + ) + + def test_instance_classmethod_mismatch(self) -> None: + class MyProtocol(Protocol): + @classmethod + def meth(cls) -> None: + pass + + class Foo: + def meth(self) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"be a class method but it's an instance method" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) class TestRecursiveType: From efa1166c85be9a1280090fea9c287b5e4e9f3830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 27 Oct 2024 12:55:53 +0200 Subject: [PATCH 05/11] Added release date --- docs/versionhistory.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index de7c154..7ab922c 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -4,7 +4,7 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. -**UNRELEASED** +**4.4.0** (2024-10-27) - Added proper checking for method signatures in protocol checks (`#465 `_) From c7f5a4fe996f67c24496aa9608a830fc3dcd6809 Mon Sep 17 00:00:00 2001 From: Vasily Zakharov Date: Tue, 29 Oct 2024 17:25:48 +0300 Subject: [PATCH 06/11] Further improvement on typeguard_ignore() annotation (#497) --- docs/versionhistory.rst | 3 +++ src/typeguard/_decorators.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 7ab922c..cb1563c 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -4,6 +4,9 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +- Changed the signature of ``typeguard_ignore()`` to be compatible with + ``typing.no_type_check()`` (PR by @jolaf) + **4.4.0** (2024-10-27) - Added proper checking for method signatures in protocol checks diff --git a/src/typeguard/_decorators.py b/src/typeguard/_decorators.py index af6f82b..1d171ec 100644 --- a/src/typeguard/_decorators.py +++ b/src/typeguard/_decorators.py @@ -21,9 +21,9 @@ if TYPE_CHECKING: from typeshed.stdlib.types import _Cell - def typeguard_ignore(f: T_CallableOrType) -> T_CallableOrType: + def typeguard_ignore(arg: T_CallableOrType) -> T_CallableOrType: """This decorator is a noop during static type-checking.""" - return f + return arg else: from typing import no_type_check as typeguard_ignore # noqa: F401 From 9a73eb0fc6115f9c5b8cd7f234affb6c481cf89e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 27 Oct 2024 13:37:59 +0200 Subject: [PATCH 07/11] Dropped Python 3.8 support --- .github/workflows/test.yml | 2 +- docs/versionhistory.rst | 3 ++ pyproject.toml | 7 +++-- src/typeguard/_checkers.py | 17 ++++++------ src/typeguard/_importhook.py | 4 +-- src/typeguard/_transformer.py | 24 ++-------------- src/typeguard/_union_transformer.py | 11 +------- src/typeguard/_utils.py | 4 +-- tests/test_checkers.py | 10 ++----- tests/test_transformer.py | 43 ++++++++--------------------- tests/test_union_transformer.py | 21 +++++++------- 11 files changed, 46 insertions(+), 100 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index df3ec80..6cbe7dd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index cb1563c..0e79eef 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -4,6 +4,9 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +**UNRELEASED** + +- Dropped Python 3.8 support - Changed the signature of ``typeguard_ignore()`` to be compatible with ``typing.no_type_check()`` (PR by @jolaf) diff --git a/pyproject.toml b/pyproject.toml index 5df4ab0..9591049 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,14 +17,13 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" dependencies = [ "importlib_metadata >= 3.6; python_version < '3.10'", "typing_extensions >= 4.10.0", @@ -90,6 +89,8 @@ extend-select = [ ignore = [ "S307", "B008", + "UP006", + "UP035", ] [tool.mypy] @@ -98,7 +99,7 @@ strict = true pretty = true [tool.tox] -env_list = ["py38", "py39", "py310", "py311", "py312", "py313"] +env_list = ["py39", "py310", "py311", "py312", "py313"] skip_missing_interpreters = true [tool.tox.env_run_base] diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 44ca34b..fa7df9f 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -6,6 +6,7 @@ import types import typing import warnings +from collections.abc import Mapping, MutableMapping, Sequence from enum import Enum from inspect import Parameter, isclass, isfunction from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase @@ -14,17 +15,15 @@ from typing import ( IO, AbstractSet, + Annotated, Any, BinaryIO, Callable, Dict, ForwardRef, List, - Mapping, - MutableMapping, NewType, Optional, - Sequence, Set, TextIO, Tuple, @@ -49,7 +48,6 @@ if sys.version_info >= (3, 11): from typing import ( - Annotated, NotRequired, TypeAlias, get_args, @@ -58,14 +56,13 @@ SubclassableAny = Any else: + from typing_extensions import Any as SubclassableAny from typing_extensions import ( - Annotated, NotRequired, TypeAlias, get_args, get_origin, ) - from typing_extensions import Any as SubclassableAny if sys.version_info >= (3, 10): from importlib.metadata import entry_points @@ -82,9 +79,11 @@ ] checker_lookup_functions: list[TypeCheckLookupCallback] = [] -generic_alias_types: tuple[type, ...] = (type(List), type(List[Any])) -if sys.version_info >= (3, 9): - generic_alias_types += (types.GenericAlias,) +generic_alias_types: tuple[type, ...] = ( + type(List), + type(List[Any]), + types.GenericAlias, +) # Sentinel _missing = object() diff --git a/src/typeguard/_importhook.py b/src/typeguard/_importhook.py index 8590540..0d1c627 100644 --- a/src/typeguard/_importhook.py +++ b/src/typeguard/_importhook.py @@ -3,14 +3,14 @@ import ast import sys import types -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from importlib.abc import MetaPathFinder from importlib.machinery import ModuleSpec, SourceFileLoader from importlib.util import cache_from_source, decode_source from inspect import isclass from os import PathLike from types import CodeType, ModuleType, TracebackType -from typing import Sequence, TypeVar +from typing import TypeVar from unittest.mock import patch from ._config import global_config diff --git a/src/typeguard/_transformer.py b/src/typeguard/_transformer.py index 13ac363..13d2cf0 100644 --- a/src/typeguard/_transformer.py +++ b/src/typeguard/_transformer.py @@ -472,12 +472,6 @@ def visit_Name(self, node: Name) -> Any: if self._memo.is_ignored_name(node): return None - if sys.version_info < (3, 9): - for typename, substitute in self.type_substitutions.items(): - if self._memo.name_matches(node, typename): - new_node = self.transformer._get_import(*substitute) - return copy_location(new_node, node) - return node def visit_Call(self, node: Call) -> Any: @@ -748,11 +742,7 @@ def visit_FunctionDef( if node.args.vararg: annotation_ = self._convert_annotation(node.args.vararg.annotation) if annotation_: - if sys.version_info >= (3, 9): - container = Name("tuple", ctx=Load()) - else: - container = self._get_import("typing", "Tuple") - + container = Name("tuple", ctx=Load()) subscript_slice: Tuple | Index = Tuple( [ annotation_, @@ -760,9 +750,6 @@ def visit_FunctionDef( ], ctx=Load(), ) - if sys.version_info < (3, 9): - subscript_slice = Index(subscript_slice, ctx=Load()) - arg_annotations[node.args.vararg.arg] = Subscript( container, subscript_slice, ctx=Load() ) @@ -770,11 +757,7 @@ def visit_FunctionDef( if node.args.kwarg: annotation_ = self._convert_annotation(node.args.kwarg.annotation) if annotation_: - if sys.version_info >= (3, 9): - container = Name("dict", ctx=Load()) - else: - container = self._get_import("typing", "Dict") - + container = Name("dict", ctx=Load()) subscript_slice = Tuple( [ Name("str", ctx=Load()), @@ -782,9 +765,6 @@ def visit_FunctionDef( ], ctx=Load(), ) - if sys.version_info < (3, 9): - subscript_slice = Index(subscript_slice, ctx=Load()) - arg_annotations[node.args.kwarg.arg] = Subscript( container, subscript_slice, ctx=Load() ) diff --git a/src/typeguard/_union_transformer.py b/src/typeguard/_union_transformer.py index 19617e6..d0a3ddf 100644 --- a/src/typeguard/_union_transformer.py +++ b/src/typeguard/_union_transformer.py @@ -18,16 +18,7 @@ ) from ast import Tuple as ASTTuple from types import CodeType -from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union - -type_substitutions = { - "dict": Dict, - "list": List, - "tuple": Tuple, - "set": Set, - "frozenset": FrozenSet, - "Union": Union, -} +from typing import Any class UnionTransformer(NodeTransformer): diff --git a/src/typeguard/_utils.py b/src/typeguard/_utils.py index 9bcc841..e8f9b03 100644 --- a/src/typeguard/_utils.py +++ b/src/typeguard/_utils.py @@ -35,7 +35,7 @@ def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: ) def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: - from ._union_transformer import compile_type_hint, type_substitutions + from ._union_transformer import compile_type_hint if not forwardref.__forward_evaluated__: forwardref.__forward_code__ = compile_type_hint(forwardref.__forward_arg__) @@ -47,8 +47,6 @@ def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: # Try again, with the type substitutions (list -> List etc.) in place new_globals = memo.globals.copy() new_globals.setdefault("Union", Union) - if sys.version_info < (3, 9): - new_globals.update(type_substitutions) return forwardref._evaluate( new_globals, memo.locals or new_globals, *evaluate_extra_args diff --git a/tests/test_checkers.py b/tests/test_checkers.py index f780964..dba7a7a 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -7,6 +7,7 @@ from typing import ( IO, AbstractSet, + Annotated, Any, AnyStr, BinaryIO, @@ -75,11 +76,6 @@ else: from typing_extensions import Concatenate, ParamSpec, TypeGuard -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - P = ParamSpec("P") @@ -941,9 +937,7 @@ def test_union_typevar(self): @pytest.mark.parametrize("check_against", [type, Type[Any]]) def test_generic_aliase(self, check_against): - if sys.version_info >= (3, 9): - check_type(dict[str, str], check_against) - + check_type(dict[str, str], check_against) check_type(Dict, check_against) check_type(Dict[str, str], check_against) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 15cf9d4..9248d50 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,16 +1,11 @@ import sys -from ast import parse +from ast import parse, unparse from textwrap import dedent import pytest from typeguard._transformer import TypeguardTransformer -if sys.version_info >= (3, 9): - from ast import unparse -else: - pytest.skip("Requires Python 3.9 or newer", allow_module_level=True) - def test_arguments_only() -> None: node = parse( @@ -1166,27 +1161,20 @@ def foo(*args: int) -> None: ) TypeguardTransformer().visit(node) - if sys.version_info < (3, 9): - extra_import = "from typing import Tuple\n" - tuple_type = "Tuple" - else: - extra_import = "" - tuple_type = "tuple" - assert ( unparse(node) == dedent( - f""" + """ from typeguard import TypeCheckMemo from typeguard._functions import check_argument_types, \ check_variable_assignment - {extra_import} + def foo(*args: int) -> None: memo = TypeCheckMemo(globals(), locals()) - check_argument_types('foo', {{'args': (args, \ -{tuple_type}[int, ...])}}, memo) + check_argument_types('foo', {'args': (args, \ +tuple[int, ...])}, memo) args = check_variable_assignment((5,), 'args', \ -{tuple_type}[int, ...], memo) +tuple[int, ...], memo) """ ).strip() ) @@ -1202,27 +1190,20 @@ def foo(**kwargs: int) -> None: ) TypeguardTransformer().visit(node) - if sys.version_info < (3, 9): - extra_import = "from typing import Dict\n" - dict_type = "Dict" - else: - extra_import = "" - dict_type = "dict" - assert ( unparse(node) == dedent( - f""" + """ from typeguard import TypeCheckMemo from typeguard._functions import check_argument_types, \ check_variable_assignment - {extra_import} + def foo(**kwargs: int) -> None: memo = TypeCheckMemo(globals(), locals()) - check_argument_types('foo', {{'kwargs': (kwargs, \ -{dict_type}[str, int])}}, memo) - kwargs = check_variable_assignment({{'a': 5}}, 'kwargs', \ -{dict_type}[str, int], memo) + check_argument_types('foo', {'kwargs': (kwargs, \ +dict[str, int])}, memo) + kwargs = check_variable_assignment({'a': 5}, 'kwargs', \ +dict[str, int], memo) """ ).strip() ) diff --git a/tests/test_union_transformer.py b/tests/test_union_transformer.py index dc45679..e6dcd25 100644 --- a/tests/test_union_transformer.py +++ b/tests/test_union_transformer.py @@ -1,13 +1,17 @@ import typing -from typing import Callable +from typing import Callable, Union import pytest from typing_extensions import Literal -from typeguard._union_transformer import compile_type_hint, type_substitutions +from typeguard._union_transformer import compile_type_hint -eval_globals = {"Callable": Callable, "Literal": Literal, "typing": typing} -eval_globals.update(type_substitutions) +eval_globals = { + "Callable": Callable, + "Literal": Literal, + "typing": typing, + "Union": Union, +} @pytest.mark.parametrize( @@ -15,12 +19,12 @@ [ ["str | int", "Union[str, int]"], ["str | int | bytes", "Union[str, int, bytes]"], - ["str | Union[int | bytes, set]", "Union[str, int, bytes, Set]"], + ["str | Union[int | bytes, set]", "Union[str, int, bytes, set]"], ["str | int | Callable[..., bytes]", "Union[str, int, Callable[..., bytes]]"], ["str | int | Callable[[], bytes]", "Union[str, int, Callable[[], bytes]]"], [ "str | int | Callable[[], bytes | set]", - "Union[str, int, Callable[[], Union[bytes, Set]]]", + "Union[str, int, Callable[[], Union[bytes, set]]]", ], ["str | int | Literal['foo']", "Union[str, int, Literal['foo']]"], ["str | int | Literal[-1]", "Union[str, int, Literal[-1]]"], @@ -29,11 +33,6 @@ 'str | int | Literal["It\'s a string \'\\""]', "Union[str, int, Literal['It\\'s a string \\'\"']]", ], - [ - "typing.Tuple | typing.List | Literal[-1]", - "Union[Tuple, List, Literal[-1]]", - ], - ["tuple[int, ...]", "Tuple[int, ...]"], ], ) def test_union_transformer(inputval: str, expected: str) -> None: From 889ad53b3d9dd1d21cb54938168dbb037b0d24e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 3 Nov 2024 00:05:13 +0200 Subject: [PATCH 08/11] Fixed checking of variable assignments involving tuple unpacking This also unified all variable checking across different assignment types (annotation assignment, augmented assignment and any other kind of assignment) Fixes #486. --- docs/versionhistory.rst | 2 + src/typeguard/_functions.py | 87 ++++++++--------- src/typeguard/_transformer.py | 146 ++++++++++++++-------------- src/typeguard/_union_transformer.py | 7 +- tests/test_transformer.py | 101 +++++++++++++++---- 5 files changed, 204 insertions(+), 139 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 0e79eef..1815197 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -9,6 +9,8 @@ This library adheres to - Dropped Python 3.8 support - Changed the signature of ``typeguard_ignore()`` to be compatible with ``typing.no_type_check()`` (PR by @jolaf) +- Fixed checking of variable assignments involving tuple unpacking + (`#486 `_) **4.4.0** (2024-10-27) diff --git a/src/typeguard/_functions.py b/src/typeguard/_functions.py index 2849785..ca21c14 100644 --- a/src/typeguard/_functions.py +++ b/src/typeguard/_functions.py @@ -2,6 +2,7 @@ import sys import warnings +from collections.abc import Sequence from typing import Any, Callable, NoReturn, TypeVar, Union, overload from . import _suppression @@ -242,59 +243,53 @@ def check_yield_type( def check_variable_assignment( - value: object, varname: str, annotation: Any, memo: TypeCheckMemo + value: Any, targets: Sequence[list[tuple[str, Any]]], memo: TypeCheckMemo ) -> Any: if _suppression.type_checks_suppressed: return value - try: - check_type_internal(value, annotation, memo) - except TypeCheckError as exc: - qualname = qualified_name(value, add_class_prefix=True) - exc.append_path_element(f"value assigned to {varname} ({qualname})") - if memo.config.typecheck_fail_callback: - memo.config.typecheck_fail_callback(exc, memo) - else: - raise - - return value - + value_to_return = value + for target in targets: + star_variable_index = next( + (i for i, (varname, _) in enumerate(target) if varname.startswith("*")), + None, + ) + if star_variable_index is not None: + value_to_return = list(value) + remaining_vars = len(target) - 1 - star_variable_index + end_index = len(value_to_return) - remaining_vars + values_to_check = ( + value_to_return[:star_variable_index] + + [value_to_return[star_variable_index:end_index]] + + value_to_return[end_index:] + ) + elif len(target) > 1: + values_to_check = value_to_return = [] + iterator = iter(value) + for _ in target: + try: + values_to_check.append(next(iterator)) + except StopIteration: + raise ValueError( + f"not enough values to unpack (expected {len(target)}, got " + f"{len(values_to_check)})" + ) from None -def check_multi_variable_assignment( - value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo -) -> Any: - if max(len(target) for target in targets) == 1: - iterated_values = [value] - else: - iterated_values = list(value) - - if not _suppression.type_checks_suppressed: - for expected_types in targets: - value_index = 0 - for ann_index, (varname, expected_type) in enumerate( - expected_types.items() - ): - if varname.startswith("*"): - varname = varname[1:] - keys_left = len(expected_types) - 1 - ann_index - next_value_index = len(iterated_values) - keys_left - obj: object = iterated_values[value_index:next_value_index] - value_index = next_value_index + else: + values_to_check = [value] + + for val, (varname, annotation) in zip(values_to_check, target): + try: + check_type_internal(val, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(val, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) else: - obj = iterated_values[value_index] - value_index += 1 + raise - try: - check_type_internal(obj, expected_type, memo) - except TypeCheckError as exc: - qualname = qualified_name(obj, add_class_prefix=True) - exc.append_path_element(f"value assigned to {varname} ({qualname})") - if memo.config.typecheck_fail_callback: - memo.config.typecheck_fail_callback(exc, memo) - else: - raise - - return iterated_values[0] if len(iterated_values) == 1 else iterated_values + return value_to_return def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None: diff --git a/src/typeguard/_transformer.py b/src/typeguard/_transformer.py index 13d2cf0..937b6b5 100644 --- a/src/typeguard/_transformer.py +++ b/src/typeguard/_transformer.py @@ -28,7 +28,6 @@ If, Import, ImportFrom, - Index, List, Load, LShift, @@ -389,9 +388,7 @@ def visit_BinOp(self, node: BinOp) -> Any: union_name = self.transformer._get_import("typing", "Union") return Subscript( value=union_name, - slice=Index( - Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() - ), + slice=Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load(), ) @@ -410,24 +407,18 @@ def visit_Subscript(self, node: Subscript) -> Any: # The subscript of typing(_extensions).Literal can be any arbitrary string, so # don't try to evaluate it as code if node.slice: - if isinstance(node.slice, Index): - # Python 3.8 - slice_value = node.slice.value # type: ignore[attr-defined] - else: - slice_value = node.slice - - if isinstance(slice_value, Tuple): + if isinstance(node.slice, Tuple): if self._memo.name_matches(node.value, *annotated_names): # Only treat the first argument to typing.Annotated as a potential # forward reference items = cast( typing.List[expr], - [self.visit(slice_value.elts[0])] + slice_value.elts[1:], + [self.visit(node.slice.elts[0])] + node.slice.elts[1:], ) else: items = cast( typing.List[expr], - [self.visit(item) for item in slice_value.elts], + [self.visit(item) for item in node.slice.elts], ) # If this is a Union and any of the items is Any, erase the entire @@ -450,7 +441,7 @@ def visit_Subscript(self, node: Subscript) -> Any: if item is None: items[index] = self.transformer._get_import("typing", "Any") - slice_value.elts = items + node.slice.elts = items else: self.generic_visit(node) @@ -542,18 +533,10 @@ def _use_memo( return_annotation, *generator_names ): if isinstance(return_annotation, Subscript): - annotation_slice = return_annotation.slice - - # Python < 3.9 - if isinstance(annotation_slice, Index): - annotation_slice = ( - annotation_slice.value # type: ignore[attr-defined] - ) - - if isinstance(annotation_slice, Tuple): - items = annotation_slice.elts + if isinstance(return_annotation.slice, Tuple): + items = return_annotation.slice.elts else: - items = [annotation_slice] + items = [return_annotation.slice] if len(items) > 0: new_memo.yield_annotation = self._convert_annotation( @@ -743,7 +726,7 @@ def visit_FunctionDef( annotation_ = self._convert_annotation(node.args.vararg.annotation) if annotation_: container = Name("tuple", ctx=Load()) - subscript_slice: Tuple | Index = Tuple( + subscript_slice = Tuple( [ annotation_, Constant(Ellipsis), @@ -1024,12 +1007,25 @@ def visit_AnnAssign(self, node: AnnAssign) -> Any: func_name = self._get_import( "typeguard._functions", "check_variable_assignment" ) + targets_arg = List( + [ + List( + [ + Tuple( + [Constant(node.target.id), annotation], + ctx=Load(), + ) + ], + ctx=Load(), + ) + ], + ctx=Load(), + ) node.value = Call( func_name, [ node.value, - Constant(node.target.id), - annotation, + targets_arg, self._memo.get_memo_name(), ], [], @@ -1047,7 +1043,7 @@ def visit_Assign(self, node: Assign) -> Any: # Only instrument function-local assignments if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)): - targets: list[dict[Constant, expr | None]] = [] + preliminary_targets: list[list[tuple[Constant, expr | None]]] = [] check_required = False for target in node.targets: elts: Sequence[expr] @@ -1058,63 +1054,63 @@ def visit_Assign(self, node: Assign) -> Any: else: continue - annotations_: dict[Constant, expr | None] = {} + annotations_: list[tuple[Constant, expr | None]] = [] for exp in elts: prefix = "" if isinstance(exp, Starred): exp = exp.value prefix = "*" + path: list[str] = [] + while isinstance(exp, Attribute): + path.insert(0, exp.attr) + exp = exp.value + if isinstance(exp, Name): - self._memo.ignored_names.add(exp.id) - name = prefix + exp.id + if not path: + self._memo.ignored_names.add(exp.id) + + path.insert(0, exp.id) + name = prefix + ".".join(path) annotation = self._memo.variable_annotations.get(exp.id) if annotation: - annotations_[Constant(name)] = annotation + annotations_.append((Constant(name), annotation)) check_required = True else: - annotations_[Constant(name)] = None + annotations_.append((Constant(name), None)) - targets.append(annotations_) + preliminary_targets.append(annotations_) if check_required: # Replace missing annotations with typing.Any - for item in targets: - for key, expression in item.items(): + targets: list[list[tuple[Constant, expr]]] = [] + for items in preliminary_targets: + target_list: list[tuple[Constant, expr]] = [] + targets.append(target_list) + for key, expression in items: if expression is None: - item[key] = self._get_import("typing", "Any") + target_list.append((key, self._get_import("typing", "Any"))) + else: + target_list.append((key, expression)) - if len(targets) == 1 and len(targets[0]) == 1: - func_name = self._get_import( - "typeguard._functions", "check_variable_assignment" - ) - target_varname = next(iter(targets[0])) - node.value = Call( - func_name, - [ - node.value, - target_varname, - targets[0][target_varname], - self._memo.get_memo_name(), - ], - [], - ) - elif targets: - func_name = self._get_import( - "typeguard._functions", "check_multi_variable_assignment" - ) - targets_arg = List( - [ - Dict(keys=list(target), values=list(target.values())) - for target in targets - ], - ctx=Load(), - ) - node.value = Call( - func_name, - [node.value, targets_arg, self._memo.get_memo_name()], - [], - ) + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + targets_arg = List( + [ + List( + [Tuple([name, ann], ctx=Load()) for name, ann in target], + ctx=Load(), + ) + for target in targets + ], + ctx=Load(), + ) + node.value = Call( + func_name, + [node.value, targets_arg, self._memo.get_memo_name()], + [], + ) return node @@ -1175,12 +1171,20 @@ def visit_AugAssign(self, node: AugAssign) -> Any: operator_call = Call( operator_func, [Name(node.target.id, ctx=Load()), node.value], [] ) + targets_arg = List( + [ + List( + [Tuple([Constant(node.target.id), annotation], ctx=Load())], + ctx=Load(), + ) + ], + ctx=Load(), + ) check_call = Call( self._get_import("typeguard._functions", "check_variable_assignment"), [ operator_call, - Constant(node.target.id), - annotation, + targets_arg, self._memo.get_memo_name(), ], [], diff --git a/src/typeguard/_union_transformer.py b/src/typeguard/_union_transformer.py index d0a3ddf..1c296d3 100644 --- a/src/typeguard/_union_transformer.py +++ b/src/typeguard/_union_transformer.py @@ -8,15 +8,14 @@ from ast import ( BinOp, BitOr, - Index, Load, Name, NodeTransformer, Subscript, + Tuple, fix_missing_locations, parse, ) -from ast import Tuple as ASTTuple from types import CodeType from typing import Any @@ -30,9 +29,7 @@ def visit_BinOp(self, node: BinOp) -> Any: if isinstance(node.op, BitOr): return Subscript( value=self.union_name, - slice=Index( - ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() - ), + slice=Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load(), ) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 9248d50..3cf735d 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -967,7 +967,7 @@ def foo(x: Any) -> None: def foo(x: Any) -> None: memo = TypeCheckMemo(globals(), locals()) y: FooBar = x - z: list[FooBar] = check_variable_assignment([y], 'z', list, \ + z: list[FooBar] = check_variable_assignment([y], [[('z', list)]], \ memo) """ ).strip() @@ -1145,7 +1145,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) - x: int = check_variable_assignment(otherfunc(), 'x', int, memo) + x: int = check_variable_assignment(otherfunc(), [[('x', int)]], \ +memo) """ ).strip() ) @@ -1173,8 +1174,8 @@ def foo(*args: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'args': (args, \ tuple[int, ...])}, memo) - args = check_variable_assignment((5,), 'args', \ -tuple[int, ...], memo) + args = check_variable_assignment((5,), \ +[[('args', tuple[int, ...])]], memo) """ ).strip() ) @@ -1202,8 +1203,8 @@ def foo(**kwargs: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'kwargs': (kwargs, \ dict[str, int])}, memo) - kwargs = check_variable_assignment({'a': 5}, 'kwargs', \ -dict[str, int], memo) + kwargs = check_variable_assignment({'a': 5}, \ +[[('kwargs', dict[str, int])]], memo) """ ).strip() ) @@ -1232,8 +1233,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) - x: int | str = check_variable_assignment(otherfunc(), 'x', \ -Union_[int, str], memo) + x: int | str = check_variable_assignment(otherfunc(), \ +[[('x', Union_[int, str])]], memo) """ ).strip() ) @@ -1256,15 +1257,15 @@ def foo() -> None: == dedent( f""" from typeguard import TypeCheckMemo - from typeguard._functions import check_multi_variable_assignment + from typeguard._functions import check_variable_assignment from typing import Any def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int z: bytes - {target} = check_multi_variable_assignment(otherfunc(), \ -[{{'x': int, 'y': Any, 'z': bytes}}], memo) + {target} = check_variable_assignment(otherfunc(), \ +[[('x', int), ('y', Any), ('z', bytes)]], memo) """ ).strip() ) @@ -1287,15 +1288,80 @@ def foo() -> None: == dedent( f""" from typeguard import TypeCheckMemo - from typeguard._functions import check_multi_variable_assignment + from typeguard._functions import check_variable_assignment from typing import Any def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int z: bytes - {target} = check_multi_variable_assignment(otherfunc(), \ -[{{'x': int, '*y': Any, 'z': bytes}}], memo) + {target} = check_variable_assignment(otherfunc(), \ +[[('x', int), ('*y', Any), ('z', bytes)]], memo) + """ + ).strip() + ) + + def test_complex_multi_assign(self) -> None: + node = parse( + dedent( + """ + def foo() -> None: + x: int + z: bytes + all = x, *y, z = otherfunc() + """ + ) + ) + TypeguardTransformer().visit(node) + target = "x, *y, z" if sys.version_info >= (3, 11) else "(x, *y, z)" + assert ( + unparse(node) + == dedent( + f""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_variable_assignment + from typing import Any + + def foo() -> None: + memo = TypeCheckMemo(globals(), locals()) + x: int + z: bytes + all = {target} = check_variable_assignment(otherfunc(), \ +[[('all', Any)], [('x', int), ('*y', Any), ('z', bytes)]], memo) + """ + ).strip() + ) + + def test_unpacking_assign_to_self(self) -> None: + node = parse( + dedent( + """ + class Foo: + + def foo(self) -> None: + x: int + (x, self.y) = 1, 'test' + """ + ) + ) + TypeguardTransformer().visit(node) + target = "x, self.y" if sys.version_info >= (3, 11) else "(x, self.y)" + assert ( + unparse(node) + == dedent( + f""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_variable_assignment + from typing import Any + + class Foo: + + def foo(self) -> None: + memo = TypeCheckMemo(globals(), locals(), \ +self_type=self.__class__) + x: int + {target} = check_variable_assignment((1, 'test'), \ +[[('x', int), ('self.y', Any)]], memo) """ ).strip() ) @@ -1321,7 +1387,7 @@ def foo(x: int) -> None: def foo(x: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'x': (x, int)}, memo) - x = check_variable_assignment(6, 'x', int, memo) + x = check_variable_assignment(6, [[('x', int)]], memo) """ ).strip() ) @@ -1422,7 +1488,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int - x = check_variable_assignment({function}(x, 6), 'x', int, memo) + x = check_variable_assignment({function}(x, 6), [[('x', int)]], \ +memo) """ ).strip() ) @@ -1471,7 +1538,7 @@ def foo(x: int) -> None: def foo(x: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'x': (x, int)}, memo) - x = check_variable_assignment(iadd(x, 6), 'x', int, memo) + x = check_variable_assignment(iadd(x, 6), [[('x', int)]], memo) """ ).strip() ) From 750286e33bcd77b6bc84f0827d96d4e0d06b9b42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 3 Nov 2024 00:07:38 +0200 Subject: [PATCH 09/11] Updated pre-commit modules --- .pre-commit-config.yaml | 6 +++--- pyproject.toml | 2 +- src/typeguard/_decorators.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 320d79a..02f80cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-case-conflict - id: check-merge-conflict @@ -14,14 +14,14 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.7.2 hooks: - id: ruff args: [--fix, --show-fixes] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.1 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: [ "typing_extensions" ] diff --git a/pyproject.toml b/pyproject.toml index 9591049..7c89494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ ignore = [ ] [tool.mypy] -python_version = "3.9" +python_version = "3.11" strict = true pretty = true diff --git a/src/typeguard/_decorators.py b/src/typeguard/_decorators.py index 1d171ec..a6c20cb 100644 --- a/src/typeguard/_decorators.py +++ b/src/typeguard/_decorators.py @@ -216,7 +216,7 @@ def typechecked( ) = None if isinstance(target, (classmethod, staticmethod)): wrapper_class = target.__class__ - target = target.__func__ + target = target.__func__ # type: ignore[assignment] retval = instrument(target) if isinstance(retval, str): From 121efd5cdb8899796e0dd38cda34d41225bb3e70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 3 Nov 2024 00:55:30 +0200 Subject: [PATCH 10/11] Fixed `TypeError` when checking a class against `type[Self]` --- docs/versionhistory.rst | 2 ++ src/typeguard/_checkers.py | 5 +++-- tests/test_checkers.py | 4 +--- tests/test_typechecked.py | 23 +++++++++++++++++++++++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 1815197..c6fd657 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -11,6 +11,8 @@ This library adheres to ``typing.no_type_check()`` (PR by @jolaf) - Fixed checking of variable assignments involving tuple unpacking (`#486 `_) +- Fixed ``TypeError`` when checking a class against ``type[Self]`` + (`#481 `_) **4.4.0** (2024-10-27) diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index fa7df9f..8166bf2 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -464,6 +464,8 @@ def check_class( if expected_class is Any: return + elif expected_class is typing_extensions.Self: + check_self(value, get_origin(expected_class), get_args(expected_class), memo) elif getattr(expected_class, "_is_protocol", False): check_protocol(value, expected_class, (), memo) elif isinstance(expected_class, TypeVar): @@ -847,8 +849,7 @@ def check_self( if isclass(value): if not issubclass(value, memo.self_type): raise TypeCheckError( - f"is not an instance of the self type " - f"({qualified_name(memo.self_type)})" + f"is not a subclass of the self type ({qualified_name(memo.self_type)})" ) elif not isinstance(value, memo.self_type): raise TypeCheckError( diff --git a/tests/test_checkers.py b/tests/test_checkers.py index dba7a7a..d2bcd81 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -36,6 +36,7 @@ ) import pytest +from typing_extensions import LiteralString from typeguard import ( CollectionCheckStrategy, @@ -64,12 +65,9 @@ ) if sys.version_info >= (3, 11): - from typing import LiteralString - SubclassableAny = Any else: from typing_extensions import Any as SubclassableAny - from typing_extensions import LiteralString if sys.version_info >= (3, 10): from typing import Concatenate, ParamSpec, TypeGuard diff --git a/tests/test_typechecked.py b/tests/test_typechecked.py index fd59230..d56f3ae 100644 --- a/tests/test_typechecked.py +++ b/tests/test_typechecked.py @@ -456,6 +456,29 @@ def method(cls, another: Self) -> None: rf"test_classmethod_arg_invalid\.\.Foo\)" ) + def test_self_type_valid(self): + class Foo: + @typechecked + def method(cls, subclass: type[Self]) -> None: + pass + + class Bar(Foo): + pass + + Foo().method(Bar) + + def test_self_type_invalid(self): + class Foo: + @typechecked + def method(cls, subclass: type[Self]) -> None: + pass + + pytest.raises(TypeCheckError, Foo().method, int).match( + rf'argument "subclass" \(class int\) is not a subclass of the self type ' + rf"\({__name__}\.{self.__class__.__name__}\." + rf"test_self_type_invalid\.\.Foo\)" + ) + class TestMock: def test_mock_argument(self): From 28dafeca1b39b8c19217ce09536e349c94734fa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 3 Nov 2024 02:13:13 +0200 Subject: [PATCH 11/11] Fixed checking of protocols on the class level Fixes #498. --- docs/versionhistory.rst | 2 ++ src/typeguard/_checkers.py | 25 ++++++++++++++----------- tests/test_checkers.py | 11 +++++++++-- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index c6fd657..28b92a3 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -13,6 +13,8 @@ This library adheres to (`#486 `_) - Fixed ``TypeError`` when checking a class against ``type[Self]`` (`#481 `_) +- Fixed checking of protocols on the class level (against ``type[SomeProtocol]``) + (`#498 `_) **4.4.0** (2024-10-27) diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 8166bf2..1ab6ee2 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -635,10 +635,8 @@ def check_io( raise TypeCheckError("is not an I/O object") -def check_signature_compatible( - subject_callable: Callable[..., Any], protocol: type, attrname: str -) -> None: - subject_sig = inspect.signature(subject_callable) +def check_signature_compatible(subject: type, protocol: type, attrname: str) -> None: + subject_sig = inspect.signature(getattr(subject, attrname)) protocol_sig = inspect.signature(getattr(protocol, attrname)) protocol_type: typing.Literal["instance", "class", "static"] = "instance" subject_type: typing.Literal["instance", "class", "static"] = "instance" @@ -652,12 +650,12 @@ def check_signature_compatible( protocol_type = "class" # Check if the subject-side method is a class method or static method - if inspect.ismethod(subject_callable) and inspect.isclass( - subject_callable.__self__ - ): - subject_type = "class" - elif not hasattr(subject_callable, "__self__"): - subject_type = "static" + if attrname in subject.__dict__: + descriptor = subject.__dict__[attrname] + if isinstance(descriptor, staticmethod): + subject_type = "static" + elif isinstance(descriptor, classmethod): + subject_type = "class" if protocol_type == "instance" and subject_type != "instance": raise TypeCheckError( @@ -714,6 +712,10 @@ def check_signature_compatible( if protocol_type == "instance": protocol_args.pop(0) + # Remove the "self" parameter from the subject arguments to match + if subject_type == "instance": + subject_args.pop(0) + for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args): if protocol_arg is None: if subject_arg.default is Parameter.empty: @@ -818,8 +820,9 @@ def check_protocol( # TODO: implement assignability checks for parameter and return value # annotations + subject = value if isclass(value) else value.__class__ try: - check_signature_compatible(subject_member, origin_type, attrname) + check_signature_compatible(subject, origin_type, attrname) except TypeCheckError as exc: raise TypeCheckError( f"is not compatible with the {origin_type.__qualname__} " diff --git a/tests/test_checkers.py b/tests/test_checkers.py index d2bcd81..526e94f 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -1069,7 +1069,11 @@ def test_raises_for_non_member(self, subject: object, predicate_type: type) -> N class TestProtocol: - def test_success(self, typing_provider: Any) -> None: + @pytest.mark.parametrize( + "instantiate", + [pytest.param(True, id="instance"), pytest.param(False, id="class")], + ) + def test_success(self, typing_provider: Any, instantiate: bool) -> None: class MyProtocol(Protocol): member: int @@ -1129,7 +1133,10 @@ def my_static_method(cls, x: int, y: str) -> None: def my_class_method(x: int, y: str) -> None: pass - check_type(Foo(), MyProtocol) + if instantiate: + check_type(Foo(), MyProtocol) + else: + check_type(Foo, type[MyProtocol]) @pytest.mark.parametrize("has_member", [True, False]) def test_member_checks(self, has_member: bool) -> None: