Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete itertools.chain interface #108

Merged
merged 13 commits into from
Apr 16, 2023
49 changes: 35 additions & 14 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,33 +148,54 @@ class chain(AsyncIterator[T]):
The resulting iterator consecutively iterates over and yields all values from
each of the ``iterables``. This is similar to converting all ``iterables`` to
sequences and concatenating them, but lazily exhausts each iterable.

The ``chain`` assumes ownership of its ``iterables`` and closes them reliably
when the ``chain`` is closed. Pass the ``iterables`` via a :py:class:`tuple` to
``chain.from_iterable`` to close only those iterables already processed.
"""

__slots__ = ("_impl",)
__slots__ = ("_iterator", "_owned_iterators")

def __init__(self, *iterables: AnyIterable[T]):
async def impl() -> AsyncIterator[T]:
for iterable in iterables:
@staticmethod
async def _chain_iterator(
any_iterables: AnyIterable[AnyIterable[T]],
) -> AsyncGenerator[T, None]:
async with ScopedIter(any_iterables) as iterables:
async for iterable in iterables:
async with ScopedIter(iterable) as iterator:
async for item in iterator:
yield item

self._impl = impl()
def __init__(
self, *iterables: AnyIterable[T], _iterables: AnyIterable[AnyIterable[T]] = ()
):
self._iterator = self._chain_iterator(iterables or _iterables)
self._owned_iterators = (
iterable
for iterable in iterables
if isinstance(iterable, AsyncIterator) and hasattr(iterable, "aclose")
)

@staticmethod
async def from_iterable(iterable: AnyIterable[AnyIterable[T]]) -> AsyncIterator[T]:
@classmethod
def from_iterable(cls, iterable: AnyIterable[AnyIterable[T]]) -> "chain[T]":
"""
Alternate constructor for :py:func:`~.chain` that lazily exhausts
iterables as well
the ``iterable`` of iterables as well

This is suitable for chaining iterables from a lazy or infinite ``iterable``.
In turn, closing the ``chain`` only closes those iterables
already fetched from ``iterable``.
"""
async with ScopedIter(iterable) as iterables:
async for sub_iterable in iterables:
async with ScopedIter(sub_iterable) as iterator:
async for item in iterator:
yield item
return cls(_iterables=iterable)

def __anext__(self) -> Awaitable[T]:
return self._impl.__anext__()
return self._iterator.__anext__()

async def aclose(self) -> None:
for iterable in self._owned_iterators:
if hasattr(iterable, "aclose"):
await iterable.aclose()
await self._iterator.aclose()


async def compress(
Expand Down
51 changes: 51 additions & 0 deletions unittests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,57 @@ async def test_chain(iterables):
)


class ACloseFacade:
    """
    Async iterator facade that records whether ``aclose`` was called on it

    The facade forwards iteration to an async iterator created from the wrapped
    iterable. Once closed, it reports exhaustion instead of forwarding.
    """

    def __init__(self, iterable):
        self.closed = False
        # keep the original iterable reachable for inspection
        self.__wrapped__ = iterable
        self._iterator = a.iter(iterable)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # only forward to the underlying iterator while still open
        if not self.closed:
            return await self._iterator.__anext__()
        raise StopAsyncIteration()

    async def aclose(self):
        inner = self._iterator
        # not every async iterator supports closing
        if hasattr(inner, "aclose"):
            await inner.aclose()
        self.closed = True


@pytest.mark.parametrize("iterables", chains)
@sync
async def test_chain_close_auto(iterables):
    """Check that fully consuming a `chain` closes every chained iterator"""
    facades = [ACloseFacade(source) for source in iterables]
    # exhaust the async chain first, then build the reference result
    collected = await a.list(a.chain(*facades))
    expected = list(itertools.chain(*iterables))
    assert collected == expected
    assert all(facade.closed for facade in facades)


# insert a known filled iterable since chain closes all that are exhausted
@pytest.mark.parametrize("iterables", [([1], *case) for case in chains])
@pytest.mark.parametrize(
    "chain_type, must_close",
    [
        (lambda iterators: a.chain(*iterators), True),
        (a.chain.from_iterable, False),
    ],
)
@sync
async def test_chain_close_partial(iterables, chain_type, must_close):
    """Check which iterators a partially consumed `chain` closes on `aclose`"""
    facades = [ACloseFacade(source) for source in iterables]
    chained = chain_type(facades)
    # fetch just the first item so that later iterators are still untouched
    first_item = await a.anext(chained)
    assert first_item == next(itertools.chain(*iterables))
    await chained.aclose()
    # the first iterator was already in use; only the remaining ones are checked
    untouched = facades[1:]
    assert all(facade.closed == must_close for facade in untouched)
    # a closed chain must stay closed no matter the state of its iterators
    leftover = await a.anext(chained, "sentinel")
    assert leftover == "sentinel"


compress_cases = [
(range(20), [idx % 2 for idx in range(20)]),
([1] * 5, [True, True, False, True, True]),
Expand Down