Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements and cleanup for typing #129

Merged
merged 24 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions asyncstdlib/_lrucache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""

from typing import (
Generic,
TypeVar,
NamedTuple,
Callable,
Any,
Expand Down Expand Up @@ -72,6 +74,14 @@ def __wrapped__(self) -> AC:
"""The callable wrapped by this cache"""
raise NotImplementedError

def __get__(
self: "LRUAsyncCallable[Any]", instance: object, owner: Optional[type] = None
maxfischer2781 marked this conversation as resolved.
Show resolved Hide resolved
) -> Any:
"""Descriptor ``__get__`` for caches to bind them on lookup"""
if instance is None:
return self
return LRUAsyncBoundCallable(self, instance)

#: Get the result of ``await __wrapped__(...)`` from the cache or evaluation
__call__: AC

Expand Down Expand Up @@ -106,23 +116,37 @@ def cache_discard(self, *args: Any, **kwargs: Any) -> None:
...


class LRUAsyncBoundCallable(LRUAsyncCallable[AC]):
# these are fake and only exist for placeholders
S = TypeVar("S")
S2 = TypeVar("S2")
P = TypeVar("P")
R = TypeVar("R")


class LRUAsyncBoundCallable(Generic[S, P, R]): # type: ignore[reportInvalidTypeVarUse]
"""A :py:class:`~.LRUAsyncCallable` that is bound like a method"""

__slots__ = ("__weakref__", "_lru", "__self__")

def __init__(self, lru: LRUAsyncCallable[AC], __self__: object):
def __init__(self, lru: LRUAsyncCallable[Any], __self__: object):
self._lru = lru
self.__self__ = __self__

@property
def __wrapped__(self) -> AC:
def __wrapped__(self) -> Any:
return self._lru.__wrapped__

@property
def __func__(self) -> LRUAsyncCallable[AC]:
def __func__(self) -> LRUAsyncCallable[Any]:
return self._lru

def __get__(
self: "LRUAsyncBoundCallable[S, P, R]",
instance: S2,
owner: Optional[type] = None,
) -> "LRUAsyncBoundCallable[S2, P, R]":
maxfischer2781 marked this conversation as resolved.
Show resolved Hide resolved
return LRUAsyncBoundCallable(self._lru, instance)

def __call__(self, *args, **kwargs): # type: ignore
return self._lru(self.__self__, *args, **kwargs)

Expand Down Expand Up @@ -289,22 +313,12 @@ def from_call(
return cls(key)


def cache__get(
self: LRUAsyncCallable[AC], instance: object, owner: Optional[type] = None
) -> LRUAsyncCallable[AC]:
"""Descriptor ``__get__`` for caches to bind them on lookup"""
if instance is None:
return self
return LRUAsyncBoundCallable(self, instance)


class UncachedLRUAsyncCallable(LRUAsyncCallable[AC]):
"""Wrap the async ``call`` to track accesses as for caching/memoization"""

__slots__ = ("__weakref__", "__dict__", "__wrapped__", "__misses", "__typed")

__wrapped__: AC
__get__ = cache__get

def __init__(self, call: AC, typed: bool):
self.__wrapped__ = call # type: ignore[reportIncompatibleMethodOverride]
Expand Down Expand Up @@ -342,7 +356,6 @@ class MemoizedLRUAsyncCallable(LRUAsyncCallable[AC]):
)

__wrapped__: AC
__get__ = cache__get

def __init__(self, call: AC, typed: bool):
self.__wrapped__ = call # type: ignore[reportIncompatibleMethodOverride]
Expand Down Expand Up @@ -397,7 +410,6 @@ class CachedLRUAsyncCallable(LRUAsyncCallable[AC]):
)

__wrapped__: AC
__get__ = cache__get

def __init__(self, call: AC, typed: bool, maxsize: int):
self.__wrapped__ = call # type: ignore[reportIncompatibleMethodOverride]
Expand Down
61 changes: 39 additions & 22 deletions asyncstdlib/_lrucache.pyi
Original file line number Diff line number Diff line change
@@ -1,61 +1,78 @@
from ._typing import AC, Protocol, R as R, TypedDict
from typing import (
TypeVar,
Any,
Awaitable,
Callable,
Generic,
NamedTuple,
Optional,
overload,
Protocol,
)
from typing_extensions import ParamSpec, Concatenate

from ._typing import AC, TypedDict

class CacheInfo(NamedTuple):
hits: int
misses: int
maxsize: Optional[int]
maxsize: int | None
currsize: int

class CacheParameters(TypedDict):
maxsize: Optional[int]
maxsize: int | None
typed: bool

R = TypeVar("R")
P = ParamSpec("P")
S = TypeVar("S")
S2 = TypeVar("S2")

class LRUAsyncCallable(Protocol[AC]):
__slots__: tuple[str, ...]
__call__: AC
@overload
def __get__(
self: LRUAsyncCallable[AC],
instance: None,
owner: Optional[type] = ...,
self: LRUAsyncCallable[AC], instance: None, owner: type | None = ...
) -> LRUAsyncCallable[AC]: ...
@overload
def __get__(
self: LRUAsyncCallable[Callable[..., Awaitable[R]]],
instance: object,
owner: Optional[type] = ...,
) -> LRUAsyncBoundCallable[Callable[..., Awaitable[R]]]: ...
self: LRUAsyncCallable[Callable[Concatenate[S, P], Awaitable[R]]],
instance: S,
owner: type | None = ...,
) -> LRUAsyncBoundCallable[S, P, R]: ...
@property
def __wrapped__(self) -> AC: ...
def cache_parameters(self) -> CacheParameters: ...
def cache_info(self) -> CacheInfo: ...
def cache_clear(self) -> None: ...
def cache_discard(self, *args: Any, **kwargs: Any) -> None: ...

class LRUAsyncBoundCallable(LRUAsyncCallable[AC]):
__self__: object
__call__: AC
class LRUAsyncBoundCallable(Generic[S, P, R]):
__slots__: tuple[str, ...]
__self__: S
__call__: Callable[P, Awaitable[R]]
def __get__(
self: LRUAsyncBoundCallable[AC],
instance: Any,
owner: Optional[type] = ...,
) -> LRUAsyncBoundCallable[AC]: ...
def __init__(self, lru: LRUAsyncCallable[AC], __self__: object) -> None: ...
self, instance: S2, owner: type | None = ...
) -> LRUAsyncBoundCallable[S2, P, R]: ...
def __init__(
self,
lru: LRUAsyncCallable[Callable[Concatenate[S, P], Awaitable[R]]],
__self__: S,
) -> None: ...
@property
def __wrapped__(self) -> AC: ...
def __wrapped__(self) -> Callable[Concatenate[S, P], Awaitable[R]]: ...
@property
def __func__(self) -> LRUAsyncCallable[AC]: ...
def __func__(
self,
) -> LRUAsyncCallable[Callable[Concatenate[S, P], Awaitable[R]]]: ...
def cache_parameters(self) -> CacheParameters: ...
def cache_info(self) -> CacheInfo: ...
def cache_clear(self) -> None: ...
def cache_discard(self, *args: Any, **kwargs: Any) -> None: ...

@overload
def lru_cache(maxsize: AC, typed: bool = ...) -> LRUAsyncCallable[AC]: ...
@overload
def lru_cache(
maxsize: Optional[int] = ..., typed: bool = ...
maxsize: int | None = ..., typed: bool = ...
) -> Callable[[AC], LRUAsyncCallable[AC]]: ...
19 changes: 16 additions & 3 deletions asyncstdlib/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
Callable,
Any,
Awaitable,
runtime_checkable,
Protocol,
ContextManager,
TypedDict,
)

from typing import Protocol, AsyncContextManager, ContextManager, TypedDict

__all__ = [
"Protocol",
"AsyncContextManager",
"ContextManager",
"TypedDict",
"T",
Expand All @@ -35,6 +36,8 @@
"HK",
"LT",
"ADD",
"AClose",
"ACloseable",
"AnyIterable",
]

Expand Down Expand Up @@ -70,5 +73,15 @@ def __add__(self: ADD, other: ADD) -> ADD:
raise NotImplementedError


# await AClose.aclose()
AClose = TypeVar("AClose", bound="ACloseable")
Fixed Show fixed Hide fixed


@runtime_checkable
class ACloseable(Protocol):
async def aclose(self) -> None:
"""Asynchronously close this object"""


#: (async) iter T
AnyIterable = Union[Iterable[T], AsyncIterable[T]]
11 changes: 6 additions & 5 deletions asyncstdlib/asynctools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from typing import (
Union,
AsyncContextManager,
AsyncIterator,
TypeVar,
AsyncGenerator,
Expand All @@ -14,7 +15,7 @@
Optional,
)

from ._typing import AsyncContextManager, T, T1, T2, T3, T4, T5, AnyIterable
from ._typing import T, T1, T2, T3, T4, T5, AnyIterable
from ._core import aiter
from .contextlib import nullcontext

Expand Down Expand Up @@ -50,11 +51,11 @@ def __init__(self, iterator: Union[AsyncIterator[T], AsyncGenerator[T, S]]):
self.__anext__ = self._wrapper.__anext__ # type: ignore
if hasattr(iterator, "asend"):
self.asend = (
iterator.asend # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
iterator.asend # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
)
if hasattr(iterator, "athrow"):
self.athrow = (
iterator.athrow # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
iterator.athrow # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
)

def __aiter__(self) -> AsyncGenerator[T, S]:
Expand Down Expand Up @@ -409,9 +410,9 @@ async def await_iter(n):
async for item in iterable:
yield (
item if not isinstance(item, Awaitable) else await item
) # pyright: ignore[reportGeneralTypeIssues]
) # pyright: ignore[reportReturnType]
else:
for item in iterable:
yield (
item if not isinstance(item, Awaitable) else await item
) # pyright: ignore[reportGeneralTypeIssues]
) # pyright: ignore[reportReturnType]
29 changes: 7 additions & 22 deletions asyncstdlib/contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
Awaitable,
Deque,
overload,
AsyncContextManager,
)
from functools import wraps
from collections import deque
from functools import partial
import sys

from ._typing import Protocol, AsyncContextManager, ContextManager, T, C
from ._typing import AClose, ContextManager, T, C
from ._core import awaitify
from ._utility import public_module

Expand All @@ -28,14 +29,6 @@
AbstractContextManager = AsyncContextManager


class ACloseable(Protocol):
async def aclose(self) -> None:
"""Asynchronously close this object"""


AC = TypeVar("AC", bound=ACloseable)


def contextmanager(
func: Callable[..., AsyncGenerator[T, None]]
) -> Callable[..., AsyncContextManager[T]]:
Expand Down Expand Up @@ -126,7 +119,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:


@public_module(__name__, "closing")
class Closing(Generic[AC]):
class Closing(Generic[AClose]):
"""
Create an :term:`asynchronous context manager` to ``aclose`` some ``thing`` on exit

Expand All @@ -150,10 +143,10 @@ class Closing(Generic[AC]):
is eventually closed and only :term:`borrowed <borrowing>` until then.
"""

def __init__(self, thing: AC):
def __init__(self, thing: AClose):
self.thing = thing

async def __aenter__(self) -> AC:
async def __aenter__(self) -> AClose:
return self.thing

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
Expand All @@ -165,7 +158,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:


@public_module(__name__, "nullcontext")
class NullContext(Generic[T]):
class NullContext(AsyncContextManager[T]):
"""
Create an :term:`asynchronous context manager` that only returns ``enter_result``

Expand Down Expand Up @@ -215,15 +208,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
nullcontext = NullContext


SE = TypeVar(
"SE",
bound=Union[
AsyncContextManager[Any],
ContextManager[Any],
Callable[[Any, BaseException, Any], Optional[bool]],
Callable[[Any, BaseException, Any], Awaitable[Optional[bool]]],
],
)
SE = TypeVar("SE")


class ExitStack:
Expand Down
Loading
Loading