diff --git a/asyncstdlib/_lrucache.py b/asyncstdlib/_lrucache.py index ef5486d..ed1abc0 100644 --- a/asyncstdlib/_lrucache.py +++ b/asyncstdlib/_lrucache.py @@ -6,7 +6,10 @@ especially when they might not apply to PyPy. """ +from __future__ import annotations from typing import ( + Generic, + TypeVar, NamedTuple, Callable, Any, @@ -72,6 +75,14 @@ def __wrapped__(self) -> AC: """The callable wrapped by this cache""" raise NotImplementedError + def __get__( + self: LRUAsyncCallable[Any], instance: object, owner: Optional[type] = None + ) -> Any: + """Descriptor ``__get__`` for caches to bind them on lookup""" + if instance is None: + return self + return LRUAsyncBoundCallable(self, instance) + #: Get the result of ``await __wrapped__(...)`` from the cache or evaluation __call__: AC @@ -106,23 +117,37 @@ def cache_discard(self, *args: Any, **kwargs: Any) -> None: ... -class LRUAsyncBoundCallable(LRUAsyncCallable[AC]): +# these are fake and only exist for placeholders +S = TypeVar("S") +S2 = TypeVar("S2") +P = TypeVar("P") +R = TypeVar("R") + + +class LRUAsyncBoundCallable(Generic[S, P, R]): # type: ignore[reportInvalidTypeVarUse] """A :py:class:`~.LRUAsyncCallable` that is bound like a method""" __slots__ = ("__weakref__", "_lru", "__self__") - def __init__(self, lru: LRUAsyncCallable[AC], __self__: object): + def __init__(self, lru: LRUAsyncCallable[Any], __self__: object): self._lru = lru self.__self__ = __self__ @property - def __wrapped__(self) -> AC: + def __wrapped__(self) -> Any: return self._lru.__wrapped__ @property - def __func__(self) -> LRUAsyncCallable[AC]: + def __func__(self) -> LRUAsyncCallable[Any]: return self._lru + def __get__( + self: LRUAsyncBoundCallable[S, P, R], + instance: S2, + owner: Optional[type] = None, + ) -> LRUAsyncBoundCallable[S2, P, R]: + return LRUAsyncBoundCallable(self._lru, instance) + def __call__(self, *args, **kwargs): # type: ignore return self._lru(self.__self__, *args, **kwargs) @@ -289,22 +314,12 @@ def from_call( return cls(key) -def cache__get( - self: LRUAsyncCallable[AC], instance: object, owner: Optional[type] = None -) -> LRUAsyncCallable[AC]: - """Descriptor ``__get__`` for caches to bind them on lookup""" - if instance is None: - return self - return LRUAsyncBoundCallable(self, instance) - - class UncachedLRUAsyncCallable(LRUAsyncCallable[AC]): """Wrap the async ``call`` to track accesses as for caching/memoization""" __slots__ = ("__weakref__", "__dict__", "__wrapped__", "__misses", "__typed") __wrapped__: AC - __get__ = cache__get def __init__(self, call: AC, typed: bool): self.__wrapped__ = call # type: ignore[reportIncompatibleMethodOverride] @@ -342,7 +357,6 @@ class MemoizedLRUAsyncCallable(LRUAsyncCallable[AC]): ) __wrapped__: AC - __get__ = cache__get def __init__(self, call: AC, typed: bool): self.__wrapped__ = call # type: ignore[reportIncompatibleMethodOverride] @@ -397,7 +411,6 @@ class CachedLRUAsyncCallable(LRUAsyncCallable[AC]): ) __wrapped__: AC - __get__ = cache__get def __init__(self, call: AC, typed: bool, maxsize: int): self.__wrapped__ = call # type: ignore[reportIncompatibleMethodOverride] diff --git a/asyncstdlib/_lrucache.pyi b/asyncstdlib/_lrucache.pyi index 38794cc..08fb24c 100644 --- a/asyncstdlib/_lrucache.pyi +++ b/asyncstdlib/_lrucache.pyi @@ -1,37 +1,45 @@ -from ._typing import AC, Protocol, R as R, TypedDict from typing import ( + TypeVar, Any, Awaitable, Callable, + Generic, NamedTuple, - Optional, overload, + Protocol, ) +from typing_extensions import ParamSpec, Concatenate + +from ._typing import AC, 
TypedDict class CacheInfo(NamedTuple): hits: int misses: int - maxsize: Optional[int] + maxsize: int | None currsize: int class CacheParameters(TypedDict): - maxsize: Optional[int] + maxsize: int | None typed: bool +R = TypeVar("R") +P = ParamSpec("P") +S = TypeVar("S") +S2 = TypeVar("S2") + class LRUAsyncCallable(Protocol[AC]): + __slots__: tuple[str, ...] __call__: AC @overload def __get__( - self: LRUAsyncCallable[AC], - instance: None, - owner: Optional[type] = ..., + self: LRUAsyncCallable[AC], instance: None, owner: type | None = ... ) -> LRUAsyncCallable[AC]: ... @overload def __get__( - self: LRUAsyncCallable[Callable[..., Awaitable[R]]], - instance: object, - owner: Optional[type] = ..., - ) -> LRUAsyncBoundCallable[Callable[..., Awaitable[R]]]: ... + self: LRUAsyncCallable[Callable[Concatenate[S, P], Awaitable[R]]], + instance: S, + owner: type | None = ..., + ) -> LRUAsyncBoundCallable[S, P, R]: ... @property def __wrapped__(self) -> AC: ... def cache_parameters(self) -> CacheParameters: ... @@ -39,23 +47,32 @@ class LRUAsyncCallable(Protocol[AC]): def cache_clear(self) -> None: ... def cache_discard(self, *args: Any, **kwargs: Any) -> None: ... -class LRUAsyncBoundCallable(LRUAsyncCallable[AC]): - __self__: object - __call__: AC +class LRUAsyncBoundCallable(Generic[S, P, R]): + __slots__: tuple[str, ...] + __self__: S + __call__: Callable[P, Awaitable[R]] def __get__( - self: LRUAsyncBoundCallable[AC], - instance: Any, - owner: Optional[type] = ..., - ) -> LRUAsyncBoundCallable[AC]: ... - def __init__(self, lru: LRUAsyncCallable[AC], __self__: object) -> None: ... + self, instance: S2, owner: type | None = ... + ) -> LRUAsyncBoundCallable[S2, P, R]: ... + def __init__( + self, + lru: LRUAsyncCallable[Callable[Concatenate[S, P], Awaitable[R]]], + __self__: S, + ) -> None: ... @property - def __wrapped__(self) -> AC: ... + def __wrapped__(self) -> Callable[Concatenate[S, P], Awaitable[R]]: ... @property - def __func__(self) -> LRUAsyncCallable[AC]: ... + def __func__( + self, + ) -> LRUAsyncCallable[Callable[Concatenate[S, P], Awaitable[R]]]: ... + def cache_parameters(self) -> CacheParameters: ... + def cache_info(self) -> CacheInfo: ... + def cache_clear(self) -> None: ... + def cache_discard(self, *args: Any, **kwargs: Any) -> None: ... @overload def lru_cache(maxsize: AC, typed: bool = ...) -> LRUAsyncCallable[AC]: ... @overload def lru_cache( - maxsize: Optional[int] = ..., typed: bool = ... + maxsize: int | None = ..., typed: bool = ... ) -> Callable[[AC], LRUAsyncCallable[AC]]: ... 
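
Note on the `_lrucache` changes above: moving `__get__` onto the protocol and typing the bound form as `LRUAsyncBoundCallable[S, P, R]` means an `lru_cache`-decorated method now binds like a regular method and has its arguments checked via `Concatenate[S, P]`. The following is a minimal sketch of that intended usage, not part of the patch; the `Portal` class and `fetch` method are invented for illustration, and it assumes `asyncstdlib` is installed.

import asyncio

from asyncstdlib.functools import lru_cache


class Portal:
    @lru_cache(maxsize=128)
    async def fetch(self, key: str) -> str:
        # Placeholder body; a real method would do asynchronous work here.
        return key.upper()


async def main() -> None:
    portal = Portal()
    # Attribute access goes through LRUAsyncCallable.__get__ and returns an
    # LRUAsyncBoundCallable with `portal` bound as __self__.
    assert portal.fetch.__self__ is portal
    # Arguments are checked against Concatenate[S, P]: `key` must be a str,
    # so e.g. `portal.fetch(b"nope")` would be flagged by a type checker.
    print(await portal.fetch("spam"))
    # Class-level access returns the unbound cache, including its statistics.
    print(Portal.fetch.cache_info())


asyncio.run(main())
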
diff --git a/asyncstdlib/_typing.py b/asyncstdlib/_typing.py index f386723..895e2d5 100644 --- a/asyncstdlib/_typing.py +++ b/asyncstdlib/_typing.py @@ -14,13 +14,14 @@ Callable, Any, Awaitable, + runtime_checkable, + Protocol, + ContextManager, + TypedDict, ) -from typing import Protocol, AsyncContextManager, ContextManager, TypedDict - __all__ = [ "Protocol", - "AsyncContextManager", "ContextManager", "TypedDict", "T", @@ -35,6 +36,8 @@ "HK", "LT", "ADD", + "AClose", + "ACloseable", "AnyIterable", ] @@ -70,5 +73,15 @@ def __add__(self: ADD, other: ADD) -> ADD: raise NotImplementedError +# await AClose.aclose() +AClose = TypeVar("AClose", bound="ACloseable") + + +@runtime_checkable +class ACloseable(Protocol): + async def aclose(self) -> None: + """Asynchronously close this object""" + + #: (async) iter T AnyIterable = Union[Iterable[T], AsyncIterable[T]] diff --git a/asyncstdlib/asynctools.py b/asyncstdlib/asynctools.py index 40dc243..5de673e 100644 --- a/asyncstdlib/asynctools.py +++ b/asyncstdlib/asynctools.py @@ -2,6 +2,7 @@ from functools import wraps from typing import ( Union, + AsyncContextManager, AsyncIterator, TypeVar, AsyncGenerator, @@ -14,7 +15,7 @@ Optional, ) -from ._typing import AsyncContextManager, T, T1, T2, T3, T4, T5, AnyIterable +from ._typing import T, T1, T2, T3, T4, T5, AnyIterable from ._core import aiter from .contextlib import nullcontext @@ -50,11 +51,11 @@ def __init__(self, iterator: Union[AsyncIterator[T], AsyncGenerator[T, S]]): self.__anext__ = self._wrapper.__anext__ # type: ignore if hasattr(iterator, "asend"): self.asend = ( - iterator.asend # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues] + iterator.asend # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] ) if hasattr(iterator, "athrow"): self.athrow = ( - iterator.athrow # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues] + iterator.athrow # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] ) def __aiter__(self) -> AsyncGenerator[T, S]: @@ -409,9 +410,9 @@ async def await_iter(n): async for item in iterable: yield ( item if not isinstance(item, Awaitable) else await item - ) # pyright: ignore[reportGeneralTypeIssues] + ) # pyright: ignore[reportReturnType] else: for item in iterable: yield ( item if not isinstance(item, Awaitable) else await item - ) # pyright: ignore[reportGeneralTypeIssues] + ) # pyright: ignore[reportReturnType] diff --git a/asyncstdlib/builtins.py b/asyncstdlib/builtins.py index 44e6b3e..2d35a0c 100644 --- a/asyncstdlib/builtins.py +++ b/asyncstdlib/builtins.py @@ -444,7 +444,7 @@ async def sorted( try: return _sync_builtins.sorted(iterable, reverse=reverse) # type: ignore except TypeError: - items: "_sync_builtins.list[Any]" = [item async for item in aiter(iterable)] + items: _sync_builtins.list[Any] = [item async for item in aiter(iterable)] items.sort(reverse=reverse) return items else: diff --git a/asyncstdlib/contextlib.py b/asyncstdlib/contextlib.py index c82bf42..a67e199 100644 --- a/asyncstdlib/contextlib.py +++ b/asyncstdlib/contextlib.py @@ -8,14 +8,14 @@ Any, Awaitable, Deque, - overload, + AsyncContextManager, ) from functools import wraps from collections import deque from functools import partial import sys -from ._typing import Protocol, AsyncContextManager, ContextManager, T, C +from ._typing import AClose, ContextManager, T, C from ._core import awaitify from ._utility import public_module @@ -28,14 +28,6 @@ AbstractContextManager = AsyncContextManager -class ACloseable(Protocol): - async def 
aclose(self) -> None: - """Asynchronously close this object""" - - -AC = TypeVar("AC", bound=ACloseable) - - def contextmanager( func: Callable[..., AsyncGenerator[T, None]] ) -> Callable[..., AsyncContextManager[T]]: @@ -126,7 +118,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: @public_module(__name__, "closing") -class Closing(Generic[AC]): +class Closing(Generic[AClose]): """ Create an :term:`asynchronous context manager` to ``aclose`` some ``thing`` on exit @@ -150,10 +142,10 @@ class Closing(Generic[AC]): is eventually closed and only :term:`borrowed ` until then. """ - def __init__(self, thing: AC): + def __init__(self, thing: AClose): self.thing = thing - async def __aenter__(self) -> AC: + async def __aenter__(self) -> AClose: return self.thing async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: @@ -165,7 +157,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: @public_module(__name__, "nullcontext") -class NullContext(Generic[T]): +class NullContext(AsyncContextManager[T]): """ Create an :term:`asynchronous context manager` that only returns ``enter_result`` @@ -190,22 +182,10 @@ async def safe_fetch(source): __slots__ = ("enter_result",) - @overload - def __init__(self: "NullContext[None]", enter_result: None = ...) -> None: ... - - @overload - def __init__(self: "NullContext[T]", enter_result: T) -> None: ... - - def __init__(self, enter_result: Optional[T] = None): + def __init__(self, enter_result: T = None): self.enter_result = enter_result - @overload - async def __aenter__(self: "NullContext[None]") -> None: ... - - @overload - async def __aenter__(self: "NullContext[T]") -> T: ... - - async def __aenter__(self) -> Optional[T]: + async def __aenter__(self) -> T: return self.enter_result async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: @@ -215,15 +195,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: nullcontext = NullContext -SE = TypeVar( - "SE", - bound=Union[ - AsyncContextManager[Any], - ContextManager[Any], - Callable[[Any, BaseException, Any], Optional[bool]], - Callable[[Any, BaseException, Any], Awaitable[Optional[bool]]], - ], -) +SE = TypeVar("SE") class ExitStack: diff --git a/asyncstdlib/contextlib.pyi b/asyncstdlib/contextlib.pyi new file mode 100644 index 0000000..e364b63 --- /dev/null +++ b/asyncstdlib/contextlib.pyi @@ -0,0 +1,79 @@ +from typing import ( + TypeVar, + Generic, + AsyncGenerator, + Callable, + Optional, + Any, + Awaitable, + overload, + AsyncContextManager, +) +from typing_extensions import ParamSpec, Self +from types import TracebackType + +from ._typing import AClose, ContextManager, T, R + +AnyContextManager = AsyncContextManager[T] | ContextManager[T] + +P = ParamSpec("P") + +def contextmanager( + func: Callable[P, AsyncGenerator[T, None]] +) -> Callable[P, AsyncContextManager[T]]: ... + +class closing(Generic[AClose]): + def __init__(self, thing: AClose) -> None: ... + async def __aenter__(self: Self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: ... + +class nullcontext(AsyncContextManager[T]): + enter_result: T + + @overload + def __init__(self: nullcontext[None], enter_result: None = ...) -> None: ... + @overload + def __init__(self: nullcontext[T], enter_result: T) -> None: ... + async def __aenter__(self: nullcontext[T]) -> T: ... 
+ async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: ... + +SE = TypeVar( + "SE", + bound=AsyncContextManager[Any] + | ContextManager[Any] + | Callable[ + [type[BaseException] | None, BaseException | None, TracebackType | None], + Optional[bool], + ] + | Callable[ + [type[BaseException] | None, BaseException | None, TracebackType | None], + Awaitable[Optional[bool]], + ], +) + +class ExitStack: + def __init__(self) -> None: ... + def pop_all(self: Self) -> Self: ... + def push(self, exit: SE) -> SE: ... + def callback( + self, callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs + ) -> Callable[P, R]: ... + async def enter_context(self, cm: AnyContextManager[T]) -> T: ... + async def aclose(self) -> None: ... + async def __aenter__(self: Self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + tb: TracebackType | None, + ) -> bool: ... diff --git a/asyncstdlib/functools.py b/asyncstdlib/functools.py index 212ff8a..d26b21a 100644 --- a/asyncstdlib/functools.py +++ b/asyncstdlib/functools.py @@ -10,12 +10,18 @@ overload, ) -from ._typing import T, T1, T2, AC, AnyIterable +from ._typing import T, AC, AnyIterable from ._core import ScopedIter, awaitify as _awaitify, Sentinel from .builtins import anext from ._utility import public_module -from ._lrucache import lru_cache, CacheInfo, CacheParameters, LRUAsyncCallable +from ._lrucache import ( + lru_cache, + CacheInfo, + CacheParameters, + LRUAsyncCallable, + LRUAsyncBoundCallable, +) __all__ = [ "cache", @@ -23,6 +29,7 @@ "CacheInfo", "CacheParameters", "LRUAsyncCallable", + "LRUAsyncBoundCallable", "reduce", "cached_property", ] @@ -38,9 +45,6 @@ def cache(user_function: AC) -> LRUAsyncCallable[AC]: return lru_cache(maxsize=None)(user_function) -__REDUCE_SENTINEL = Sentinel("") - - class AwaitableValue(Generic[T]): """Helper to provide an arbitrary value in ``await``""" @@ -147,6 +151,8 @@ def __get__( ) -> Union["CachedProperty[T]", Awaitable[T]]: if instance is None: return self + # __get__ may be called multiple times before it is first awaited to completion + # provide a placeholder that acts just like the final value does return _RepeatableCoroutine(self._get_attribute, instance) async def _get_attribute(self, instance: object) -> T: @@ -158,16 +164,7 @@ async def _get_attribute(self, instance: object) -> T: cached_property = CachedProperty -@overload -def reduce( - function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1 -) -> Coroutine[T1, Any, Any]: ... - - -@overload -def reduce( - function: Callable[[T, T], T], iterable: AnyIterable[T] -) -> Coroutine[T, Any, Any]: ... +__REDUCE_SENTINEL = Sentinel("") async def reduce( diff --git a/asyncstdlib/functools.pyi b/asyncstdlib/functools.pyi new file mode 100644 index 0000000..72fe8ab --- /dev/null +++ b/asyncstdlib/functools.pyi @@ -0,0 +1,26 @@ +from typing import Any, Awaitable, Callable, Generic, overload + +from ._typing import T, T1, T2, AC, AnyIterable + +from ._lrucache import ( + LRUAsyncCallable as LRUAsyncCallable, + LRUAsyncBoundCallable as LRUAsyncBoundCallable, + lru_cache as lru_cache, +) + +def cache(user_function: AC) -> LRUAsyncCallable[AC]: ... + +class cached_property(Generic[T]): + def __init__(self, getter: Callable[[Any], Awaitable[T]]) -> None: ... + def __set_name__(self, owner: Any, name: str) -> None: ... 
+ @overload + def __get__(self, instance: None, owner: type) -> "cached_property[T]": ... + @overload + def __get__(self, instance: object, owner: type | None) -> Awaitable[T]: ... + +@overload +async def reduce( + function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1 +) -> T1: ... +@overload +async def reduce(function: Callable[[T, T], T], iterable: AnyIterable[T]) -> T: ... diff --git a/asyncstdlib/heapq.py b/asyncstdlib/heapq.py index 4397530..a48b637 100644 --- a/asyncstdlib/heapq.py +++ b/asyncstdlib/heapq.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import ( Generic, AsyncIterator, @@ -13,7 +14,7 @@ from .builtins import enumerate as a_enumerate, zip as a_zip from ._core import aiter, awaitify, ScopedIter, borrow -from ._typing import AnyIterable, LT, T +from ._typing import AnyIterable, ACloseable, LT, T class _KeyIter(Generic[LT]): @@ -92,36 +93,13 @@ async def pull_head(self) -> bool: self.head_key = await self.key(head) if self.key is not None else head return True - def __lt__(self, other: "_KeyIter[LT]") -> bool: + def __lt__(self, other: _KeyIter[LT]) -> bool: return self.reverse ^ (self.head_key < other.head_key) - def __eq__(self, other: "_KeyIter[LT]") -> bool: # type: ignore[override] + def __eq__(self, other: _KeyIter[LT]) -> bool: # type: ignore[override] return not (self.head_key < other.head_key or other.head_key < self.head_key) -@overload -def merge( - *iterables: AnyIterable[LT], key: None = ..., reverse: bool = ... -) -> AsyncIterator[LT]: - pass - - -@overload -def merge( - *iterables: AnyIterable[T], - key: Callable[[T], Awaitable[LT]] = ..., - reverse: bool = ..., -) -> AsyncIterator[T]: - pass - - -@overload -def merge( - *iterables: AnyIterable[T], key: Callable[[T], LT] = ..., reverse: bool = ... -) -> AsyncIterator[T]: - pass - - async def merge( *iterables: AnyIterable[Any], key: Optional[Callable[[Any], Any]] = None, @@ -172,7 +150,7 @@ async def merge( yield item finally: for itr, _ in iter_heap: - if hasattr(itr.tail, "aclose"): + if isinstance(itr.tail, ACloseable): await itr.tail.aclose() @@ -184,7 +162,7 @@ class ReverseLT(Generic[LT]): def __init__(self, key: LT): self.key = key - def __lt__(self, other: "ReverseLT[LT]") -> bool: + def __lt__(self, other: ReverseLT[LT]) -> bool: return other.key < self.key @@ -193,7 +171,7 @@ def __lt__(self, other: "ReverseLT[LT]") -> bool: # In other words, during search we maintain opposite sort order than what is requested. # We turn the min-heap into a max-sort in the end. async def _largest( - iterable: AsyncIterator[T], + iterable: AnyIterable[T], n: int, key: Callable[[T], Awaitable[LT]], reverse: bool, @@ -226,7 +204,7 @@ async def _identity(x: T) -> T: async def nlargest( - iterable: AsyncIterator[T], + iterable: AnyIterable[T], n: int, key: Optional[Callable[[Any], Awaitable[Any]]] = None, ) -> List[T]: @@ -248,7 +226,7 @@ async def nlargest( async def nsmallest( - iterable: AsyncIterator[T], + iterable: AnyIterable[T], n: int, key: Optional[Callable[[Any], Awaitable[Any]]] = None, ) -> List[T]: diff --git a/asyncstdlib/heapq.pyi b/asyncstdlib/heapq.pyi new file mode 100644 index 0000000..1006eb2 --- /dev/null +++ b/asyncstdlib/heapq.pyi @@ -0,0 +1,40 @@ +from typing import AsyncIterator, Awaitable, Callable, overload + +from ._typing import AnyIterable, T, LT + +@overload +def merge( + *iterables: AnyIterable[LT], key: None = ..., reverse: bool = ... +) -> AsyncIterator[LT]: ... 
+@overload +def merge( + *iterables: AnyIterable[T], key: Callable[[T], Awaitable[LT]], reverse: bool = ... +) -> AsyncIterator[T]: ... +@overload +def merge( + *iterables: AnyIterable[T], key: Callable[[T], LT], reverse: bool = ... +) -> AsyncIterator[T]: ... +@overload +async def nlargest( + iterable: AsyncIterator[LT], n: int, key: None = ... +) -> list[LT]: ... +@overload +async def nlargest( + iterable: AsyncIterator[T], n: int, key: Callable[[T], Awaitable[LT]] +) -> list[T]: ... +@overload +async def nlargest( + iterable: AsyncIterator[T], n: int, key: Callable[[T], LT] +) -> list[T]: ... +@overload +async def nsmallest( + iterable: AsyncIterator[LT], n: int, key: None = ... +) -> list[LT]: ... +@overload +async def nsmallest( + iterable: AsyncIterator[T], n: int, key: Callable[[T], Awaitable[LT]] +) -> list[T]: ... +@overload +async def nsmallest( + iterable: AsyncIterator[T], n: int, key: Callable[[T], LT] +) -> list[T]: ... diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index 2bd3948..3e47be7 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -1,6 +1,7 @@ from typing import ( Any, TypeVar, + AsyncContextManager, AsyncIterator, List, Awaitable, @@ -14,12 +15,10 @@ Tuple, overload, AsyncGenerator, - Protocol, - runtime_checkable, ) from collections import deque -from ._typing import T, R, T1, T2, T3, T4, T5, AnyIterable, ADD, AsyncContextManager +from ._typing import ACloseable, T, AnyIterable, ADD from ._utility import public_module from ._core import ( ScopedIter, @@ -72,30 +71,6 @@ async def add(x: ADD, y: ADD) -> ADD: return x + y -@overload -def accumulate(iterable: AnyIterable[ADD]) -> AsyncIterator[ADD]: ... - - -@overload -def accumulate(iterable: AnyIterable[ADD], *, initial: ADD) -> AsyncIterator[ADD]: ... - - -@overload -def accumulate( - iterable: AnyIterable[T], - function: Union[Callable[[T, T], T], Callable[[T, T], Awaitable[T]]], -) -> AsyncIterator[T]: ... - - -@overload -def accumulate( - iterable: AnyIterable[T], - function: Union[Callable[[T, T], T], Callable[[T, T], Awaitable[T]]], - *, - initial: T, -) -> AsyncIterator[T]: ... - - async def accumulate( iterable: AnyIterable[Any], function: Union[ @@ -158,12 +133,6 @@ async def batched(iterable: AnyIterable[T], n: int) -> AsyncIterator[Tuple[T, .. 
yield batch -@runtime_checkable -class _ACloseable(Protocol): - async def aclose(self) -> None: - """Asynchronously close this object""" - - class chain(AsyncIterator[T]): """ An :term:`asynchronous iterator` flattening values from all ``iterables`` @@ -196,7 +165,7 @@ def __init__( self._owned_iterators = tuple( iterable # type: ignore[misc] for iterable in iterables - if isinstance(iterable, AsyncIterator) and isinstance(iterable, _ACloseable) + if isinstance(iterable, AsyncIterator) and isinstance(iterable, ACloseable) ) @classmethod @@ -221,7 +190,7 @@ async def aclose(self) -> None: async def compress( - data: AnyIterable[T], selectors: AnyIterable[bool] + data: AnyIterable[T], selectors: AnyIterable[Any] ) -> AsyncIterator[T]: """ An :term:`asynchronous iterator` for items of ``data`` with true ``selectors`` @@ -242,8 +211,7 @@ async def compress(data, selectors): async def dropwhile( - predicate: Union[Callable[[T], bool], Callable[[T], Awaitable[bool]]], - iterable: AnyIterable[T], + predicate: Callable[[T], Any], iterable: AnyIterable[T] ) -> AsyncIterator[T]: """ Yield items from ``iterable`` after ``predicate(item)`` is no longer true @@ -271,10 +239,7 @@ async def filterfalse( """ Yield items from ``iterable`` for which ``predicate(item)`` is false. - If ``predicate`` is ``None``, return items which are false. - - Lazily iterates over ``iterable``, yielding only items for which - ``predicate`` of the current item is false. + If ``predicate`` is ``None``, yield any items which are false. """ async with ScopedIter(iterable) as async_iter: if predicate is None: @@ -342,8 +307,7 @@ async def starmap( async def takewhile( - predicate: Union[Callable[[T], bool], Callable[[T], Awaitable[bool]]], - iterable: AnyIterable[T], + predicate: Callable[[T], Any], iterable: AnyIterable[T] ) -> AsyncIterator[T]: """ Yield items from ``iterable`` as long as ``predicate(item)`` is true @@ -412,7 +376,7 @@ async def tee_peer( peers.pop(idx) break # if we are the last peer, try and close the iterator - if not peers and isinstance(iterator, _ACloseable): + if not peers and isinstance(iterator, ACloseable): await iterator.aclose() @@ -529,70 +493,6 @@ async def _repeat(value: T) -> AsyncIterator[T]: yield value -@overload -def zip_longest( - __it1: AnyIterable[T1], - *, - fillvalue: S = ..., -) -> AsyncIterator[Tuple[T1]]: ... - - -@overload -def zip_longest( - __it1: AnyIterable[T1], - __it2: AnyIterable[T2], - *, - fillvalue: S = ..., -) -> AsyncIterator[Tuple[Union[T1, S], Union[T2, S]]]: ... - - -@overload -def zip_longest( - __it1: AnyIterable[T1], - __it2: AnyIterable[T2], - __it3: AnyIterable[T3], - *, - fillvalue: S = ..., -) -> AsyncIterator[Tuple[Union[T1, S], Union[T2, S], Union[T3, S]]]: ... - - -@overload -def zip_longest( - __it1: AnyIterable[T1], - __it2: AnyIterable[T2], - __it3: AnyIterable[T3], - __it4: AnyIterable[T4], - *, - fillvalue: S = ..., -) -> AsyncIterator[Tuple[Union[T1, S], Union[T2, S], Union[T3, S], Union[T4, S]]]: ... - - -@overload -def zip_longest( - __it1: AnyIterable[T1], - __it2: AnyIterable[T2], - __it3: AnyIterable[T3], - __it4: AnyIterable[T4], - __it5: AnyIterable[T5], - *, - fillvalue: S = ..., -) -> AsyncIterator[ - Tuple[Union[T1, S], Union[T2, S], Union[T3, S], Union[T4, S], Union[T5, S]] -]: ... - - -@overload -def zip_longest( - __it1: AnyIterable[Any], - __it2: AnyIterable[Any], - __it3: AnyIterable[Any], - __it4: AnyIterable[Any], - __it5: AnyIterable[Any], - *iterables: AnyIterable[Any], - fillvalue: S = ..., -) -> AsyncIterator[Tuple[Any, ...]]: ... 
- - async def zip_longest( *iterables: AnyIterable[Any], fillvalue: Any = None ) -> AsyncIterator[Tuple[Any, ...]]: @@ -617,7 +517,7 @@ async def zip_longest( try: remaining = len(async_iters) while True: - values = [] + values: list[Any] = [] for index, aiterator in enumerate(async_iters): try: value = await anext(aiterator) @@ -634,12 +534,8 @@ async def zip_longest( finally: await fill_iter.aclose() # type: ignore for iterator in async_iters: - try: - aclose = iterator.aclose() # type: ignore - except AttributeError: - pass - else: - await aclose + if isinstance(iterator, ACloseable): + await iterator.aclose() async def identity(x: T) -> T: @@ -647,25 +543,7 @@ async def identity(x: T) -> T: return x -@overload # noqa: F811 -def groupby( # noqa: F811 - iterable: AnyIterable[T], -) -> AsyncIterator[Tuple[T, AsyncIterator[T]]]: ... - - -@overload # noqa: F811 -def groupby( # noqa: F811 - iterable: AnyIterable[T], key: None -) -> AsyncIterator[Tuple[T, AsyncIterator[T]]]: ... - - -@overload # noqa: F811 -def groupby( # noqa: F811 - iterable: AnyIterable[T], key: Union[Callable[[T], R], Callable[[T], Awaitable[R]]] -) -> AsyncIterator[Tuple[R, AsyncIterator[T]]]: ... - - -async def groupby( # noqa: F811 +async def groupby( iterable: AnyIterable[Any], key: Optional[ Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]] @@ -693,7 +571,7 @@ async def groupby( # noqa: F811 # whether the current group was exhausted and the next begins already exhausted = False # `current_*`: buffer for key/value the current group peeked beyond its end - current_key = current_value = nothing = object() # type: Any + current_key = current_value = nothing = object() make_key: Callable[[Any], Awaitable[Any]] = ( _awaitify(key) if key is not None else identity # type: ignore ) diff --git a/asyncstdlib/itertools.pyi b/asyncstdlib/itertools.pyi new file mode 100644 index 0000000..d78c285 --- /dev/null +++ b/asyncstdlib/itertools.pyi @@ -0,0 +1,238 @@ +from typing import ( + Any, + AsyncIterator, + AsyncContextManager, + Awaitable, + Generic, + Iterator, + Iterable, + Callable, + TypeVar, + Self, + overload, +) +from typing_extensions import Literal + +from ._typing import AnyIterable, ADD, T, T1, T2, T3, T4, T5 + +def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]: ... +@overload +def accumulate(iterable: AnyIterable[ADD]) -> AsyncIterator[ADD]: ... +@overload +def accumulate(iterable: AnyIterable[ADD], *, initial: ADD) -> AsyncIterator[ADD]: ... +@overload +def accumulate( + iterable: AnyIterable[T], + function: Callable[[T, T], T] | Callable[[T, T], Awaitable[T]], +) -> AsyncIterator[T]: ... +@overload +def accumulate( + iterable: AnyIterable[T2], + function: Callable[[T1, T2], T1] | Callable[[T1, T2], Awaitable[T1]], + *, + initial: T1, +) -> AsyncIterator[T1]: ... +@overload +def batched(iterable: AnyIterable[T], n: Literal[1]) -> AsyncIterator[tuple[T]]: ... +@overload +def batched(iterable: AnyIterable[T], n: Literal[2]) -> AsyncIterator[tuple[T, T]]: ... +@overload +def batched( + iterable: AnyIterable[T], n: Literal[3] +) -> AsyncIterator[tuple[T, T, T]]: ... +@overload +def batched( + iterable: AnyIterable[T], n: Literal[4] +) -> AsyncIterator[tuple[T, T, T, T]]: ... +@overload +def batched( + iterable: AnyIterable[T], n: Literal[5] +) -> AsyncIterator[tuple[T, T, T, T, T]]: ... +@overload +def batched( + iterable: AnyIterable[T], n: Literal[6] +) -> AsyncIterator[tuple[T, T, T, T, T, T]]: ... +@overload +def batched(iterable: AnyIterable[T], n: int) -> AsyncIterator[tuple[T, ...]]: ... 
+ +class chain(AsyncIterator[T]): + __slots__: tuple[str, ...] + def __init__(self, *iterables: AnyIterable[T]) -> None: ... + @classmethod + def from_iterable(cls, iterable: AnyIterable[AnyIterable[T]]) -> chain[T]: ... + async def __anext__(self) -> T: ... + async def aclose(self) -> None: ... + +def compress(data: AnyIterable[T], selectors: AnyIterable[Any]) -> AsyncIterator[T]: ... +async def dropwhile( + predicate: Callable[[T], Any], iterable: AnyIterable[T] +) -> AsyncIterator[T]: ... +async def filterfalse( + predicate: Callable[[T], Any] | None, iterable: AnyIterable[T] +) -> AsyncIterator[T]: ... +@overload +async def islice( + iterable: AnyIterable[T], start: int | None, / +) -> AsyncIterator[T]: ... +@overload +async def islice( + iterable: AnyIterable[T], + start: int | None, + stop: int | None, + step: int | None = None, + /, +) -> AsyncIterator[T]: ... +@overload +async def starmap( + function: Callable[[T1], T] | Callable[[T1], Awaitable[T]], + iterable: AnyIterable[tuple[T1]], +) -> AsyncIterator[T]: ... +@overload +async def starmap( + function: Callable[[T1, T2], T] | Callable[[T1, T2], Awaitable[T]], + iterable: AnyIterable[tuple[T1, T2]], +) -> AsyncIterator[T]: ... +@overload +async def starmap( + function: Callable[[T1, T2, T3], T] | Callable[[T1, T2, T3], Awaitable[T]], + iterable: AnyIterable[tuple[T1, T2, T3]], +) -> AsyncIterator[T]: ... +@overload +async def starmap( + function: Callable[[T1, T2, T3, T4], T] | Callable[[T1, T2, T3, T4], Awaitable[T]], + iterable: AnyIterable[tuple[T1, T2, T3, T4]], +) -> AsyncIterator[T]: ... +@overload +async def starmap( + function: ( + Callable[[T1, T2, T3, T4, T5], T] | Callable[[T1, T2, T3, T4, T5], Awaitable[T]] + ), + iterable: AnyIterable[tuple[T1, T2, T3, T4, T5]], +) -> AsyncIterator[T]: ... +@overload +async def starmap( + function: Callable[..., T] | Callable[..., Awaitable[T]], + iterable: AnyIterable[Iterable[Any]], +) -> AsyncIterator[T]: ... +def takewhile( + predicate: Callable[[T], Any], iterable: AnyIterable[T] +) -> AsyncIterator[T]: ... + +class tee(Generic[T]): + __slots__: tuple[str, ...] + + def __init__( + self, + iterable: AnyIterable[T], + n: int = ..., + *, + lock: AsyncContextManager[Any] | None = ..., + ) -> None: ... + def __len__(self) -> int: ... + @overload + def __getitem__(self, item: int) -> AsyncIterator[T]: ... + @overload + def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ... + def __iter__(self) -> Iterator[AnyIterable[T]]: ... + async def __aenter__(self: Self) -> Self: ... + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: ... + async def aclose(self) -> None: ... + +def pairwise(iterable: AnyIterable[T]) -> AsyncIterator[tuple[T, T]]: ... + +F = TypeVar("F") + +@overload +def zip_longest( + __it1: AnyIterable[T1], *, fillvalue: Any = ... +) -> AsyncIterator[tuple[T1]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], +) -> AsyncIterator[tuple[T1 | None, T2 | None]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + *, + fillvalue: F, +) -> AsyncIterator[tuple[T1 | F, T2 | F]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + __it3: AnyIterable[T3], +) -> AsyncIterator[tuple[T1 | None, T2 | None, T3 | None]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + __it3: AnyIterable[T3], + *, + fillvalue: F, +) -> AsyncIterator[tuple[T1 | F, T2 | F, T3 | F]]: ... 
+@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + __it3: AnyIterable[T3], + __it4: AnyIterable[T4], +) -> AsyncIterator[tuple[T1 | None, T2 | None, T3 | None, T4 | None]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + __it3: AnyIterable[T3], + __it4: AnyIterable[T4], + *, + fillvalue: F, +) -> AsyncIterator[tuple[T1 | F, T2 | F, T3 | F, T4 | F]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + __it3: AnyIterable[T3], + __it4: AnyIterable[T4], + __it5: AnyIterable[T5], +) -> AsyncIterator[tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T1], + __it2: AnyIterable[T2], + __it3: AnyIterable[T3], + __it4: AnyIterable[T4], + __it5: AnyIterable[T5], + *, + fillvalue: F, +) -> AsyncIterator[tuple[T1 | F, T2 | F, T3 | F, T4 | F, T5 | F]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T], + __it2: AnyIterable[T], + __it3: AnyIterable[T], + __it4: AnyIterable[T], + __it5: AnyIterable[T], + *iterables: AnyIterable[T], +) -> AsyncIterator[tuple[T | None, ...]]: ... +@overload +def zip_longest( + __it1: AnyIterable[T], + __it2: AnyIterable[T], + __it3: AnyIterable[T], + __it4: AnyIterable[T], + __it5: AnyIterable[T], + *iterables: AnyIterable[T], + fillvalue: F, +) -> AsyncIterator[tuple[T | F, ...]]: ... + +K = TypeVar("K") + +@overload +def groupby( + iterable: AnyIterable[T], key: None = ... +) -> AsyncIterator[tuple[T, AsyncIterator[T]]]: ... +@overload +def groupby( + iterable: AnyIterable[T], key: Callable[[T], Awaitable[K]] | Callable[[T], K] +) -> AsyncIterator[tuple[K, AsyncIterator[T]]]: ... diff --git a/pyproject.toml b/pyproject.toml index 44edace..395a6f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ test = [ "pytest-cov", "flake8-2020", "mypy; implementation_name=='cpython'", + "typing-extensions", ] doc = ["sphinx", "sphinxcontrib-trio"] diff --git a/typetests/test_functools.py b/typetests/test_functools.py index 361971e..9b56bfd 100644 --- a/typetests/test_functools.py +++ b/typetests/test_functools.py @@ -15,9 +15,14 @@ class TestLRUMethod: """ Test that `lru_cache` works on methods """ + @lru_cache() - async def cached(self) -> int: - return 1 + async def cached(self, a: int = 0) -> int: + return a async def test_implicit_self(self) -> int: return await self.cached() + + async def test_method_parameters(self) -> int: + await self.cached("wrong parameter type") # type: ignore[arg-type] + return await self.cached(12)
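
The `hasattr(..., "aclose")` cleanup checks replaced in `heapq.py` and `itertools.py` now go through the runtime-checkable `ACloseable` protocol added to `_typing.py`. Below is a minimal sketch of that check, not part of the patch; it assumes the internal `asyncstdlib._typing` module is importable (user code would not normally need it), and the `numbers` generator is invented for illustration.

import asyncio

from asyncstdlib._typing import ACloseable


async def numbers():
    for i in range(3):
        yield i


async def main() -> None:
    agen = numbers()
    # Async generators provide `aclose`, so they satisfy the protocol ...
    assert isinstance(agen, ACloseable)
    # ... while objects without `aclose` do not, so cleanup can be skipped.
    assert not isinstance(object(), ACloseable)
    if isinstance(agen, ACloseable):
        await agen.aclose()


asyncio.run(main())
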