From 8c96a1f0bf9c8a4b6aede6485363a20f4100c2c2 Mon Sep 17 00:00:00 2001 From: DABND19 Date: Mon, 6 May 2024 22:02:28 +0300 Subject: [PATCH 1/2] feat: Added typehints for aggregate and aggregate_async. --- aiomisc/aggregate.py | 45 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/aiomisc/aggregate.py b/aiomisc/aggregate.py index be4bb1c5..f447c470 100644 --- a/aiomisc/aggregate.py +++ b/aiomisc/aggregate.py @@ -4,7 +4,18 @@ from asyncio import CancelledError, Event, Future, Lock, wait_for from dataclasses import dataclass from inspect import Parameter -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Union +from typing import ( + Any, + Awaitable, + Callable, + Generic, + Iterable, + List, + Optional, + Protocol, + TypeVar, + Union, +) from .compat import EventLoopMixin from .counters import Statistic @@ -13,10 +24,14 @@ log = logging.getLogger(__name__) +V = TypeVar("V") +R = TypeVar("R") + + @dataclass(frozen=True) -class Arg: - value: Any - future: Future +class Arg(Generic[V, R]): + value: V + future: "Future[R]" class ResultNotSetError(Exception): @@ -191,7 +206,23 @@ async def _execute(self, *, args: list, futures: List[Future]) -> None: future.set_exception(ResultNotSetError) -def aggregate(leeway_ms: float, max_count: Optional[int] = None) -> Callable: +S = TypeVar("S", contravariant=True) +T = TypeVar("T", covariant=True) + + +class AggregateFunc(Protocol, Generic[S, T]): + async def __call__(self, *args: S) -> Iterable[T]: + ... + + +class AggregateAsyncFunc(Protocol, Generic[V, R]): + async def __call__(self, *args: Arg[V, R]) -> None: + ... + + +def aggregate( + leeway_ms: float, max_count: Optional[int] = None +) -> Callable[[AggregateFunc[S, T]], Callable[[S], Awaitable[T]]]: """ Parametric decorator that aggregates multiple (but no more than ``max_count`` defaulting to ``None``) single-argument @@ -229,8 +260,8 @@ def decorator(func: AggFuncHighLevel) -> Callable[[Any], Awaitable]: def aggregate_async( - leeway_ms: float, max_count: Optional[int] = None, -) -> Callable: + leeway_ms: float, max_count: Optional[int] = None, +) -> Callable[[AggregateAsyncFunc[V, R]], Callable[[V], Awaitable[R]]]: """ Same as ``aggregate``, but with ``func`` arguments of type ``Arg`` containing ``value`` and ``future`` attributes instead. In this setting From a29c543c290a99967b3897d3ee2f15ee69999771 Mon Sep 17 00:00:00 2001 From: DABND19 Date: Tue, 7 May 2024 11:18:36 +0300 Subject: [PATCH 2/2] feat: Added annotations for Aggregator and AggregatorAsync. --- aiomisc/aggregate.py | 138 +++++++++++++++++++++++++------------------ 1 file changed, 82 insertions(+), 56 deletions(-) diff --git a/aiomisc/aggregate.py b/aiomisc/aggregate.py index f447c470..34d15f09 100644 --- a/aiomisc/aggregate.py +++ b/aiomisc/aggregate.py @@ -1,4 +1,5 @@ import asyncio +import functools import inspect import logging from asyncio import CancelledError, Event, Future, Lock, wait_for @@ -14,7 +15,6 @@ Optional, Protocol, TypeVar, - Union, ) from .compat import EventLoopMixin @@ -38,9 +38,11 @@ class ResultNotSetError(Exception): pass -AggFuncHighLevel = Callable[[Any], Awaitable[Iterable]] -AggFuncAsync = Callable[[Arg], Awaitable] -AggFunc = Union[AggFuncHighLevel, AggFuncAsync] +class AggregateAsyncFunc(Protocol, Generic[V, R]): + __name__: str + + async def __call__(self, *args: Arg[V, R]) -> None: + ... class AggregateStatistic(Statistic): @@ -51,27 +53,30 @@ class AggregateStatistic(Statistic): done: int -class Aggregator(EventLoopMixin): +def _has_variadic_positional(func: Callable[..., Any]) -> bool: + return any( + parameter.kind == Parameter.VAR_POSITIONAL + for parameter in inspect.signature(func).parameters.values() + ) - _func: AggFunc + +class AggregatorAsync(EventLoopMixin, Generic[V, R]): + + _func: AggregateAsyncFunc[V, R] _max_count: Optional[int] _leeway: float _first_call_at: Optional[float] _args: list - _futures: List[Future] + _futures: "List[Future[R]]" _event: Event _lock: Lock def __init__( - self, func: AggFunc, *, leeway_ms: float, + self, func: AggregateAsyncFunc[V, R], *, leeway_ms: float, max_count: Optional[int] = None, statistic_name: Optional[str] = None, ): - has_variadic_positional = any( - parameter.kind == Parameter.VAR_POSITIONAL - for parameter in inspect.signature(func).parameters.values() - ) - if not has_variadic_positional: + if not _has_variadic_positional(func): raise ValueError( "Function must accept variadic positional arguments", ) @@ -109,9 +114,18 @@ def leeway_ms(self) -> float: def count(self) -> int: return len(self._args) - async def _execute(self, *, args: list, futures: List[Future]) -> None: + async def _execute( + self, + *, + args: List[V], + futures: "List[Future[R]]", + ) -> None: + args_ = [ + Arg(value=arg, future=future) + for arg, future in zip(args, futures) + ] try: - results = await self._func(*args) + await self._func(*args_) self._statistic.success += 1 except CancelledError: # Other waiting tasks can try to finish the job instead. @@ -123,31 +137,29 @@ async def _execute(self, *, args: list, futures: List[Future]) -> None: finally: self._statistic.done += 1 - self._set_results(results, futures) - - def _set_results(self, results: Iterable, futures: List[Future]) -> None: - for future, result in zip(futures, results): + # Validate that all results/exceptions are set by the func + for future in futures: if not future.done(): - future.set_result(result) + future.set_exception(ResultNotSetError) def _set_exception( - self, exc: Exception, futures: List[Future], + self, exc: Exception, futures: List["Future[R]"], ) -> None: for future in futures: if not future.done(): future.set_exception(exc) - async def aggregate(self, arg: Any) -> Any: + async def aggregate(self, arg: V) -> R: if self._first_call_at is None: self._first_call_at = self.loop.time() first_call_at = self._first_call_at args: list = self._args - futures: List[Future] = self._futures + futures: "List[Future[R]]" = self._futures event: Event = self._event lock: Lock = self._lock args.append(arg) - future: Future = Future() + future: "Future[R]" = Future() futures.append(future) if self.count == self.max_count: @@ -180,49 +192,61 @@ async def aggregate(self, arg: Any) -> Any: return future.result() -class AggregatorAsync(Aggregator): - - async def _execute(self, *, args: list, futures: List[Future]) -> None: - args = [ - Arg(value=arg, future=future) - for arg, future in zip(args, futures) - ] - try: - await self._func(*args) - self._statistic.success += 1 - except CancelledError: - # Other waiting tasks can try to finish the job instead. - raise - except Exception as e: - self._set_exception(e, futures) - self._statistic.error += 1 - return - finally: - self._statistic.done += 1 - - # Validate that all results/exceptions are set by the func - for future in futures: - if not future.done(): - future.set_exception(ResultNotSetError) - - S = TypeVar("S", contravariant=True) T = TypeVar("T", covariant=True) class AggregateFunc(Protocol, Generic[S, T]): + __name__: str + async def __call__(self, *args: S) -> Iterable[T]: ... -class AggregateAsyncFunc(Protocol, Generic[V, R]): - async def __call__(self, *args: Arg[V, R]) -> None: - ... +def _to_async_aggregate(func: AggregateFunc[V, R]) -> AggregateAsyncFunc[V, R]: + @functools.wraps( + func, + assigned=tuple( + item + for item in functools.WRAPPER_ASSIGNMENTS + if item != "__annotations__" + ), + ) + async def wrapper(*args: Arg[V, R]) -> None: + args_ = [item.value for item in args] + results = await func(*args_) + for res, arg in zip(results, args): + if not arg.future.done(): + arg.future.set_result(res) + + return wrapper + + +class Aggregator(AggregatorAsync[V, R], Generic[V, R]): + def __init__( + self, + func: AggregateFunc[V, R], + *, + leeway_ms: float, + max_count: Optional[int] = None, + statistic_name: Optional[str] = None, + ) -> None: + if not _has_variadic_positional(func): + raise ValueError( + "Function must accept variadic positional arguments", + ) + + super().__init__( + _to_async_aggregate(func), + leeway_ms=leeway_ms, + max_count=max_count, + statistic_name=statistic_name, + ) def aggregate( leeway_ms: float, max_count: Optional[int] = None -) -> Callable[[AggregateFunc[S, T]], Callable[[S], Awaitable[T]]]: +) -> Callable[[AggregateFunc[V, R]], Callable[[V], Awaitable[R]]]: """ Parametric decorator that aggregates multiple (but no more than ``max_count`` defaulting to ``None``) single-argument @@ -251,7 +275,7 @@ def aggregate( :return: """ - def decorator(func: AggFuncHighLevel) -> Callable[[Any], Awaitable]: + def decorator(func: AggregateFunc[V, R]) -> Callable[[V], Awaitable[R]]: aggregator = Aggregator( func, max_count=max_count, leeway_ms=leeway_ms, ) @@ -272,7 +296,9 @@ def aggregate_async( :return: """ - def decorator(func: AggFuncAsync) -> Callable[[Any], Awaitable]: + def decorator( + func: AggregateAsyncFunc[V, R] + ) -> Callable[[V], Awaitable[R]]: aggregator = AggregatorAsync( func, max_count=max_count, leeway_ms=leeway_ms, )