Skip to content

Commit

Permalink
Merge pull request #155 from mjpieters/refacored_groupby
Browse files Browse the repository at this point in the history
Refactor groupby to use classes
  • Loading branch information
maxfischer2781 authored Aug 19, 2024
2 parents ce2d0d4 + 5d23d88 commit cb745e8
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 56 deletions.
174 changes: 123 additions & 51 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
Iterable,
Iterator,
Tuple,
cast,
overload,
AsyncGenerator,
)
from collections import deque

from ._typing import ACloseable, T, AnyIterable, ADD
from ._typing import ACloseable, R, T, AnyIterable, ADD
from ._utility import public_module
from ._core import (
ScopedIter,
Expand All @@ -35,6 +36,7 @@
)

S = TypeVar("S")
T_co = TypeVar("T_co", covariant=True)


async def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]:
Expand Down Expand Up @@ -542,12 +544,86 @@ async def identity(x: T) -> T:
return x


async def groupby(
iterable: AnyIterable[Any],
key: Optional[
Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]]
] = identity,
) -> AsyncIterator[Tuple[Any, AsyncIterator[Any]]]:
class _GroupByState(Generic[R, T_co]):
"""Internal state for the groupby iterator, shared between the parent and groups"""

__slots__ = (
"_iterator",
"_key_func",
"_current_value",
"target_key",
"current_key",
"current_group",
)

_sentinel = cast(T_co, object())

def __init__(
self, iterator: AsyncIterator[T_co], key_func: Callable[[T_co], Awaitable[R]]
):
self._iterator = iterator
self._key_func = key_func
self._current_value = self._sentinel

async def step(self) -> None:
# can raise StopAsyncIteration
value = await anext(self._iterator)
key = await self._key_func(value)
self._current_value, self.current_key = value, key

async def maybe_step(self) -> None:
"""Only step if there is no current value"""
if self._current_value is self._sentinel:
await self.step()

def consume_value(self) -> T_co:
"""Return the current value, after removing it from the current state"""
value, self._current_value = self._current_value, self._sentinel
return value

async def aclose(self) -> None:
"""Close the underlying iterator"""
if (group := self.current_group) is not None:
await group.aclose()
if isinstance(self._iterator, ACloseable):
await self._iterator.aclose()


class _Grouper(AsyncIterator[T_co], Generic[R, T_co]):
"""A single group iterator, part of a series of groups yielded by groupby"""

__slots__ = ("_target_key", "_state")

def __init__(self, target_key: R, state: "_GroupByState[R, T_co]") -> None:
self._target_key = target_key
self._state = state

async def __anext__(self) -> T_co:
state = self._state
if state.current_group is not self:
raise StopAsyncIteration

await state.maybe_step()
if self._target_key != state.current_key:
raise StopAsyncIteration

return state.consume_value()

async def aclose(self) -> None:
"""Close the group iterator
Note: this does _not_ close the underlying groupby managed iterator;
closing a single group shouldn't affect other groups in the series.
"""
state = self._state
if state.current_group is not self:
return
state.current_group = None


@public_module(__name__, "groupby")
class GroupBy(AsyncIterator[Tuple[R, AsyncIterator[T_co]]], Generic[R, T_co]):
"""
Create an async iterator over consecutive keys and groups from the (async) iterable
Expand All @@ -567,49 +643,45 @@ async def groupby(
required up-front for sorting, this loses the advantage of asynchronous,
lazy iteration and evaluation.
"""
# whether the current group was exhausted and the next begins already
exhausted = False
# `current_*`: buffer for key/value the current group peeked beyond its end
current_key = current_value = nothing = object()
make_key: Callable[[Any], Awaitable[Any]] = (
_awaitify(key) if key is not None else identity # type: ignore
)
async with ScopedIter(iterable) as async_iter:
# fast-forward mode: advance to the next group
async def seek_group() -> AsyncIterator[Any]:
nonlocal current_value, current_key, exhausted
# Note: `value` always ends up being some T
# - value is something: we can never unset it
# - value is `nothing`: the previous group was not exhausted,
# and we scan at least one new value
value: Any = current_value
if not exhausted:
previous_key = current_key
while previous_key == current_key:
value = await anext(async_iter)
current_key = await make_key(value)
current_value = nothing
exhausted = False
return group(current_key, value=value)

# the lazy iterable of all items with the same key
async def group(desired_key: Any, value: Any) -> AsyncIterator[Any]:
nonlocal current_value, current_key, exhausted
yield value
async for value in async_iter:
next_key: Any = await make_key(value)
if next_key == desired_key:
yield value
else:
exhausted = True
current_value = value
current_key = next_key
break

__slots__ = ("_state",)

def __init__(
self,
iterable: AnyIterable[T_co],
key: Optional[
Union[Callable[[T_co], R], Callable[[T_co], Awaitable[R]]]
] = None,
):
key_func = (
cast(Callable[[T_co], Awaitable[R]], identity)
if key is None
else _awaitify(key)
)
self._state = _GroupByState(aiter(iterable), key_func)

async def __anext__(self) -> Tuple[R, AsyncIterator[T_co]]:
state = self._state
# disable the last group to avoid concurrency
# issues.
state.current_group = None
await state.maybe_step()
try:
while True:
next_group = await seek_group()
async with ScopedIter(next_group) as scoped_group:
yield current_key, scoped_group
except StopAsyncIteration:
return
target_key = state.target_key
except AttributeError:
# no target key yet, skip scanning
pass
else:
# scan to the next group
while state.current_key == target_key:
await state.step()

state.target_key = current_key = state.current_key
state.current_group = group = _Grouper(current_key, state)
return (current_key, group)

async def aclose(self) -> None:
await self._state.aclose()


groupby = GroupBy
12 changes: 7 additions & 5 deletions asyncstdlib/itertools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,15 @@ def zip_longest(
fillvalue: F,
) -> AsyncIterator[tuple[T | F, ...]]: ...

K = TypeVar("K")
K_co = TypeVar("K_co", covariant=True)
T_co = TypeVar("T_co", covariant=True)

@overload
def groupby(
iterable: AnyIterable[T], key: None = ...
) -> AsyncIterator[tuple[T, AsyncIterator[T]]]: ...
iterable: AnyIterable[T_co], key: None = ...
) -> AsyncIterator[tuple[T_co, AsyncIterator[T_co]]]: ...
@overload
def groupby(
iterable: AnyIterable[T], key: Callable[[T], Awaitable[K]] | Callable[[T], K]
) -> AsyncIterator[tuple[K, AsyncIterator[T]]]: ...
iterable: AnyIterable[T_co],
key: Callable[[T_co], Awaitable[K_co]] | Callable[[T], K_co],
) -> AsyncIterator[tuple[K_co, AsyncIterator[T_co]]]: ...

0 comments on commit cb745e8

Please sign in to comment.