diff --git a/typing_extensions/CHANGELOG b/typing_extensions/CHANGELOG index a9a59804..970bbd4a 100644 --- a/typing_extensions/CHANGELOG +++ b/typing_extensions/CHANGELOG @@ -1,6 +1,9 @@ # Unreleased -- Add `typing.assert_type`. Backport from bpo-46480. +- Add `typing_extensions.get_overloads` and + `typing_extensions.clear_overloads`, and add registry support to + `typing_extensions.overload`. Backport from python/cpython#89263. +- Add `typing_extensions.assert_type`. Backport from bpo-46480. - Drop support for Python 3.6. Original patch by Adam Turner (@AA-Turner). # Release 4.1.1 (February 13, 2022) diff --git a/typing_extensions/README.rst b/typing_extensions/README.rst index 9abed044..3a23b755 100644 --- a/typing_extensions/README.rst +++ b/typing_extensions/README.rst @@ -47,6 +47,8 @@ This module currently contains the following: - ``assert_never`` - ``assert_type`` + - ``clear_overloads`` + - ``get_overloads`` - ``LiteralString`` (see PEP 675) - ``Never`` - ``NotRequired`` (see PEP 655) @@ -122,6 +124,10 @@ Certain objects were changed after they were added to ``typing``, and Python 3.8 and lack support for ``ParamSpecArgs`` and ``ParamSpecKwargs`` in 3.9. - ``@final`` was changed in Python 3.11 to set the ``.__final__`` attribute. +- ``@overload`` was changed in Python 3.11 to make function overloads + introspectable at runtime. In order to access overloads with + ``typing_extensions.get_overloads()``, you must use + ``@typing_extensions.overload``. There are a few types whose interface was modified between different versions of typing. For example, ``typing.Sequence`` was modified to diff --git a/typing_extensions/src/test_typing_extensions.py b/typing_extensions/src/test_typing_extensions.py index 1439e517..ab03244d 100644 --- a/typing_extensions/src/test_typing_extensions.py +++ b/typing_extensions/src/test_typing_extensions.py @@ -3,6 +3,7 @@ import abc import contextlib import collections +from collections import defaultdict import collections.abc from functools import lru_cache import inspect @@ -10,6 +11,7 @@ import subprocess import types from unittest import TestCase, main, skipUnless, skipIf +from unittest.mock import patch from test import ann_module, ann_module2, ann_module3 import typing from typing import TypeVar, Optional, Union, Any, AnyStr @@ -21,9 +23,10 @@ from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired -from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict +from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString from typing_extensions import assert_type, get_type_hints, get_origin, get_args +from typing_extensions import clear_overloads, get_overloads, overload # Flags used to mark tests that only apply after a specific # version of the typing module. @@ -403,6 +406,20 @@ def test_no_multiple_subscripts(self): Literal[1][1] +class MethodHolder: + @classmethod + def clsmethod(cls): ... + @staticmethod + def stmethod(): ... + def method(self): ... + + +if TYPING_3_11_0: + registry_holder = typing +else: + registry_holder = typing_extensions + + class OverloadTests(BaseTestCase): def test_overload_fails(self): @@ -424,6 +441,61 @@ def blah(): blah() + def set_up_overloads(self): + def blah(): + pass + + overload1 = blah + overload(blah) + + def blah(): + pass + + overload2 = blah + overload(blah) + + def blah(): + pass + + return blah, [overload1, overload2] + + # Make sure we don't clear the global overload registry + @patch( + f"{registry_holder.__name__}._overload_registry", + defaultdict(lambda: defaultdict(dict)) + ) + def test_overload_registry(self): + registry = registry_holder._overload_registry + # The registry starts out empty + self.assertEqual(registry, {}) + + impl, overloads = self.set_up_overloads() + self.assertNotEqual(registry, {}) + self.assertEqual(list(get_overloads(impl)), overloads) + + def some_other_func(): pass + overload(some_other_func) + other_overload = some_other_func + def some_other_func(): pass + self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + + # Make sure that after we clear all overloads, the registry is + # completely empty. + clear_overloads() + self.assertEqual(registry, {}) + self.assertEqual(get_overloads(impl), []) + + # Querying a function with no overloads shouldn't change the registry. + def the_only_one(): pass + self.assertEqual(get_overloads(the_only_one), []) + self.assertEqual(registry, {}) + + def test_overload_registry_repeated(self): + for _ in range(2): + impl, overloads = self.set_up_overloads() + + self.assertEqual(list(get_overloads(impl)), overloads) + class AssertTypeTests(BaseTestCase): diff --git a/typing_extensions/src/typing_extensions.py b/typing_extensions/src/typing_extensions.py index d5e40497..49110999 100644 --- a/typing_extensions/src/typing_extensions.py +++ b/typing_extensions/src/typing_extensions.py @@ -1,6 +1,7 @@ import abc import collections import collections.abc +import functools import operator import sys import types as _types @@ -46,7 +47,9 @@ 'Annotated', 'assert_never', 'assert_type', + 'clear_overloads', 'dataclass_transform', + 'get_overloads', 'final', 'get_args', 'get_origin', @@ -249,7 +252,72 @@ def __getitem__(self, parameters): _overload_dummy = typing._overload_dummy # noqa -overload = typing.overload + + +if hasattr(typing, "get_overloads"): # 3.11+ + overload = typing.overload + get_overloads = typing.get_overloads + clear_overloads = typing.clear_overloads +else: + # {module: {qualname: {firstlineno: func}}} + _overload_registry = collections.defaultdict( + functools.partial(collections.defaultdict, dict) + ) + + def overload(func): + """Decorator for overloaded functions/methods. + + In a stub file, place two or more stub definitions for the same + function in a row, each decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + + In a non-stub file (i.e. a regular .py file), do the same but + follow it with an implementation. The implementation should *not* + be decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. + """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][ + f.__code__.co_firstlineno + ] = func + except AttributeError: + # Not a normal function; ignore. + pass + return _overload_dummy + + def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() # This is not a real generic class. Don't use outside annotations.