Skip to content

Commit

Permalink
Complete itertools.chain interface (#108)
Browse files Browse the repository at this point in the history
* added chain.aclose method for cleanup (closes #107)

* do not reconstruct chain implementation again and again

* use same implementation for chain and chain.from_iterable

* chain owns explicitly passed iterables
  • Loading branch information
maxfischer2781 authored Apr 16, 2023
1 parent d20a48b commit b966a95
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 14 deletions.
49 changes: 35 additions & 14 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,33 +148,54 @@ class chain(AsyncIterator[T]):
The resulting iterator consecutively iterates over and yields all values from
each of the ``iterables``. This is similar to converting all ``iterables`` to
sequences and concatenating them, but lazily exhausts each iterable.
The ``chain`` assumes ownership of its ``iterables`` and closes them reliably
when the ``chain`` is closed. To close only those iterables that were already
processed, pass the ``iterables`` as a :py:class:`tuple` to ``chain.from_iterable``.
"""

__slots__ = ("_impl",)
__slots__ = ("_iterator", "_owned_iterators")

def __init__(self, *iterables: AnyIterable[T]):
async def impl() -> AsyncIterator[T]:
for iterable in iterables:
@staticmethod
async def _chain_iterator(
    any_iterables: AnyIterable[AnyIterable[T]],
) -> AsyncGenerator[T, None]:
    """
    Yield the items of each iterable provided by ``any_iterables``, in order

    The outer iterable and every inner iterable are each iterated inside
    their own ``ScopedIter`` context, one inner iterable at a time.
    """
    # NOTE(review): ScopedIter is defined elsewhere; presumably it cleans up
    # the wrapped iterator when its context exits - confirm against the helper.
    async with ScopedIter(any_iterables) as outer:
        async for current in outer:
            async with ScopedIter(current) as inner:
                async for element in inner:
                    yield element

self._impl = impl()
def __init__(
    self, *iterables: AnyIterable[T], _iterables: AnyIterable[AnyIterable[T]] = ()
):
    """
    Set up the chain over the positional ``iterables``

    :param iterables: the iterables to chain, passed positionally
    :param _iterables: private alternative source, an iterable of iterables;
        ``from_iterable`` passes its argument here. It is only used when no
        positional ``iterables`` are given.
    """
    # Positional iterables take precedence; an empty tuple falls through.
    source = iterables if iterables else _iterables
    self._iterator = self._chain_iterator(source)
    # Only positionally passed async iterators exposing ``aclose`` are owned.
    # This stays a lazy generator: the filter runs when ``aclose`` consumes it.
    self._owned_iterators = (
        candidate
        for candidate in iterables
        if isinstance(candidate, AsyncIterator) and hasattr(candidate, "aclose")
    )

@staticmethod
async def from_iterable(iterable: AnyIterable[AnyIterable[T]]) -> AsyncIterator[T]:
@classmethod
def from_iterable(cls, iterable: AnyIterable[AnyIterable[T]]) -> "chain[T]":
"""
Alternate constructor for :py:func:`~.chain` that lazily exhausts
iterables as well
the ``iterable`` of iterables as well
This is suitable for chaining iterables from a lazy or infinite ``iterable``.
In turn, closing the ``chain`` only closes those iterables
already fetched from ``iterable``.
"""
async with ScopedIter(iterable) as iterables:
async for sub_iterable in iterables:
async with ScopedIter(sub_iterable) as iterator:
async for item in iterator:
yield item
return cls(_iterables=iterable)

def __anext__(self) -> Awaitable[T]:
return self._impl.__anext__()
return self._iterator.__anext__()

async def aclose(self) -> None:
    """
    Close all owned iterators and then the chain itself

    Owned iterators are closed first, in the order they were passed, and the
    chain's internal iterator is closed last. The internal iterator is closed
    even if closing an owned iterator raises; that exception still propagates
    to the caller afterwards.
    """
    try:
        # ``_owned_iterators`` is already filtered in ``__init__`` to objects
        # that provide ``aclose``, so no further attribute check is needed.
        for owned in self._owned_iterators:
            await owned.aclose()
    finally:
        # Never leave the internal generator open because cleanup failed above.
        await self._iterator.aclose()


async def compress(
Expand Down
51 changes: 51 additions & 0 deletions unittests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,57 @@ async def test_chain(iterables):
)


class ACloseFacade:
    """Async iterator wrapper that records whether ``aclose`` has been called"""

    def __init__(self, iterable):
        # ``closed`` flips to True only after ``aclose`` completed
        self.closed = False
        self.__wrapped__ = iterable
        self._iterator = a.iter(iterable)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # a closed facade behaves like an exhausted iterator
        if not self.closed:
            return await self._iterator.__anext__()
        raise StopAsyncIteration()

    async def aclose(self):
        inner = self._iterator
        # forward the close only if the wrapped iterator supports it
        if hasattr(inner, "aclose"):
            await inner.aclose()
        self.closed = True


@pytest.mark.parametrize("iterables", chains)
@sync
async def test_chain_close_auto(iterables):
    """Test that `chain` closes exhausted iterators"""
    facades = [ACloseFacade(source) for source in iterables]
    # run the async chain to exhaustion first, then build the reference result
    chained = await a.list(a.chain(*facades))
    expected = list(itertools.chain(*iterables))
    assert chained == expected
    # every wrapped iterator was exhausted, so each must have been closed
    assert all(facade.closed for facade in facades)


# prepend a known non-empty iterable: chain closes every iterable it exhausts,
# so the first item must come from an iterable that is not yet exhausted
@pytest.mark.parametrize("iterables", [([1], *tail) for tail in chains])
@pytest.mark.parametrize(
    "chain_type, must_close",
    [(lambda iterators: a.chain(*iterators), True), (a.chain.from_iterable, False)],
)
@sync
async def test_chain_close_partial(iterables, chain_type, must_close):
    """Test that `chain` closes owned iterators"""
    facades = [ACloseFacade(source) for source in iterables]
    chained = chain_type(facades)
    # fetch only the very first item, then close the chain prematurely
    first_item = await a.anext(chained)
    assert first_item == next(itertools.chain(*iterables))
    await chained.aclose()
    # iterables after the first are closed only if the chain owns them
    remaining = facades[1:]
    assert all(facade.closed == must_close for facade in remaining)
    # closed chain must remain closed regardless of iterators
    leftover = await a.anext(chained, "sentinel")
    assert leftover == "sentinel"


compress_cases = [
(range(20), [idx % 2 for idx in range(20)]),
([1] * 5, [True, True, False, True, True]),
Expand Down

0 comments on commit b966a95

Please sign in to comment.