From 5d23d88ef1b621878aa17c11732ae0fbd59fb2ab Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Thu, 8 Aug 2024 14:55:07 +0100 Subject: [PATCH] Refactor groupby to use classes The implementation is closely modelled on the cpython C implementation of the groupby iterator. --- asyncstdlib/itertools.py | 174 +++++++++++++++++++++++++++----------- asyncstdlib/itertools.pyi | 12 +-- 2 files changed, 130 insertions(+), 56 deletions(-) diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index f0e09ef..f00c196 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -13,12 +13,13 @@ Iterable, Iterator, Tuple, + cast, overload, AsyncGenerator, ) from collections import deque -from ._typing import ACloseable, T, AnyIterable, ADD +from ._typing import ACloseable, R, T, AnyIterable, ADD from ._utility import public_module from ._core import ( ScopedIter, @@ -35,6 +36,7 @@ ) S = TypeVar("S") +T_co = TypeVar("T_co", covariant=True) async def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]: @@ -542,12 +544,86 @@ async def identity(x: T) -> T: return x -async def groupby( - iterable: AnyIterable[Any], - key: Optional[ - Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]] - ] = identity, -) -> AsyncIterator[Tuple[Any, AsyncIterator[Any]]]: +class _GroupByState(Generic[R, T_co]): + """Internal state for the groupby iterator, shared between the parent and groups""" + + __slots__ = ( + "_iterator", + "_key_func", + "_current_value", + "target_key", + "current_key", + "current_group", + ) + + _sentinel = cast(T_co, object()) + + def __init__( + self, iterator: AsyncIterator[T_co], key_func: Callable[[T_co], Awaitable[R]] + ): + self._iterator = iterator + self._key_func = key_func + self._current_value = self._sentinel + + async def step(self) -> None: + # can raise StopAsyncIteration + value = await anext(self._iterator) + key = await self._key_func(value) + self._current_value, self.current_key = value, key + + async def maybe_step(self) -> None: + """Only step if there is no current value""" + if self._current_value is self._sentinel: + await self.step() + + def consume_value(self) -> T_co: + """Return the current value, after removing it from the current state""" + value, self._current_value = self._current_value, self._sentinel + return value + + async def aclose(self) -> None: + """Close the underlying iterator""" + if (group := self.current_group) is not None: + await group.aclose() + if isinstance(self._iterator, ACloseable): + await self._iterator.aclose() + + +class _Grouper(AsyncIterator[T_co], Generic[R, T_co]): + """A single group iterator, part of a series of groups yielded by groupby""" + + __slots__ = ("_target_key", "_state") + + def __init__(self, target_key: R, state: "_GroupByState[R, T_co]") -> None: + self._target_key = target_key + self._state = state + + async def __anext__(self) -> T_co: + state = self._state + if state.current_group is not self: + raise StopAsyncIteration + + await state.maybe_step() + if self._target_key != state.current_key: + raise StopAsyncIteration + + return state.consume_value() + + async def aclose(self) -> None: + """Close the group iterator + + Note: this does _not_ close the underlying groupby managed iterator; + closing a single group shouldn't affect other groups in the series. + + """ + state = self._state + if state.current_group is not self: + return + state.current_group = None + + +@public_module(__name__, "groupby") +class GroupBy(AsyncIterator[Tuple[R, AsyncIterator[T_co]]], Generic[R, T_co]): """ Create an async iterator over consecutive keys and groups from the (async) iterable @@ -567,49 +643,45 @@ async def groupby( required up-front for sorting, this loses the advantage of asynchronous, lazy iteration and evaluation. """ - # whether the current group was exhausted and the next begins already - exhausted = False - # `current_*`: buffer for key/value the current group peeked beyond its end - current_key = current_value = nothing = object() - make_key: Callable[[Any], Awaitable[Any]] = ( - _awaitify(key) if key is not None else identity # type: ignore - ) - async with ScopedIter(iterable) as async_iter: - # fast-forward mode: advance to the next group - async def seek_group() -> AsyncIterator[Any]: - nonlocal current_value, current_key, exhausted - # Note: `value` always ends up being some T - # - value is something: we can never unset it - # - value is `nothing`: the previous group was not exhausted, - # and we scan at least one new value - value: Any = current_value - if not exhausted: - previous_key = current_key - while previous_key == current_key: - value = await anext(async_iter) - current_key = await make_key(value) - current_value = nothing - exhausted = False - return group(current_key, value=value) - - # the lazy iterable of all items with the same key - async def group(desired_key: Any, value: Any) -> AsyncIterator[Any]: - nonlocal current_value, current_key, exhausted - yield value - async for value in async_iter: - next_key: Any = await make_key(value) - if next_key == desired_key: - yield value - else: - exhausted = True - current_value = value - current_key = next_key - break + __slots__ = ("_state",) + + def __init__( + self, + iterable: AnyIterable[T_co], + key: Optional[ + Union[Callable[[T_co], R], Callable[[T_co], Awaitable[R]]] + ] = None, + ): + key_func = ( + cast(Callable[[T_co], Awaitable[R]], identity) + if key is None + else _awaitify(key) + ) + self._state = _GroupByState(aiter(iterable), key_func) + + async def __anext__(self) -> Tuple[R, AsyncIterator[T_co]]: + state = self._state + # disable the last group to avoid concurrency + # issues. + state.current_group = None + await state.maybe_step() try: - while True: - next_group = await seek_group() - async with ScopedIter(next_group) as scoped_group: - yield current_key, scoped_group - except StopAsyncIteration: - return + target_key = state.target_key + except AttributeError: + # no target key yet, skip scanning + pass + else: + # scan to the next group + while state.current_key == target_key: + await state.step() + + state.target_key = current_key = state.current_key + state.current_group = group = _Grouper(current_key, state) + return (current_key, group) + + async def aclose(self) -> None: + await self._state.aclose() + + +groupby = GroupBy diff --git a/asyncstdlib/itertools.pyi b/asyncstdlib/itertools.pyi index 699aed7..a525b6a 100644 --- a/asyncstdlib/itertools.pyi +++ b/asyncstdlib/itertools.pyi @@ -223,13 +223,15 @@ def zip_longest( fillvalue: F, ) -> AsyncIterator[tuple[T | F, ...]]: ... -K = TypeVar("K") +K_co = TypeVar("K_co", covariant=True) +T_co = TypeVar("T_co", covariant=True) @overload def groupby( - iterable: AnyIterable[T], key: None = ... -) -> AsyncIterator[tuple[T, AsyncIterator[T]]]: ... + iterable: AnyIterable[T_co], key: None = ... +) -> AsyncIterator[tuple[T_co, AsyncIterator[T_co]]]: ... @overload def groupby( - iterable: AnyIterable[T], key: Callable[[T], Awaitable[K]] | Callable[[T], K] -) -> AsyncIterator[tuple[K, AsyncIterator[T]]]: ... + iterable: AnyIterable[T_co], + key: Callable[[T_co], Awaitable[K_co]] | Callable[[T], K_co], +) -> AsyncIterator[tuple[K_co, AsyncIterator[T_co]]]: ...