
Commit e9511de: async for loop
amoffat committed Sep 6, 2021
1 parent f690eaa commit e9511de
Showing 2 changed files with 85 additions and 2 deletions.
47 changes: 45 additions & 2 deletions sh.py
@@ -27,6 +27,8 @@

from collections import deque

import asyncio

try:
from collections.abc import Mapping
except ImportError:
@@ -63,6 +65,7 @@
import warnings
import weakref
from queue import Queue, Empty
from asyncio import Queue as AQueue
from io import StringIO, BytesIO
from shlex import quote as shlex_quote

@@ -650,6 +653,12 @@ def __init__(self, cmd, call_args, stdin, stdout, stderr):
should_wait = True
spawn_process = True

# if we're using an async for loop on this object, we need to put the underlying
# iterable in no-block mode. however, we will only know if we're using an async
# for loop after this object is constructed. so we'll set it to False now, but
# then later set it to True if we need it
self._force_noblock = False

# this is used to track if we've already raised StopIteration, and if we
# have, raise it immediately again if the user tries to call next() on
# us. https://github.com/amoffat/sh/issues/273
@@ -860,7 +869,7 @@ def __next__(self):
True, self.call_args["iter_poll_time"]
)
except Empty:
if self.call_args["iter_noblock"]:
if self.call_args["iter_noblock"] or self._force_noblock:
return errno.EWOULDBLOCK
else:
if chunk is None:
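
The `except Empty` branch above is the only behavioral change in `__next__`: it now returns `errno.EWOULDBLOCK` when either the user asked for `_iter_noblock` or the new `_force_noblock` flag is set. A minimal standalone sketch of that polling pattern, using hypothetical names rather than sh's internals:

```python
import errno
import queue

def iter_chunks(pipe_queue, noblock=True, poll_time=0.1):
    # Yield chunks from a thread-fed queue. In no-block mode, yield
    # errno.EWOULDBLOCK instead of waiting when nothing is ready yet,
    # so the caller can decide when to come back and poll again.
    while True:
        try:
            chunk = pipe_queue.get(True, poll_time)
        except queue.Empty:
            if noblock:
                yield errno.EWOULDBLOCK
            continue
        if chunk is None:  # None marks the end of the stream
            return
        yield chunk
```
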
@@ -874,6 +883,39 @@
except UnicodeDecodeError:
return chunk

def __aiter__(self):
# maxsize is critical to making sure our queue_connector function below yields
# when it awaits _aio_queue.put(chunk). if we didn't have a maxsize, our loop
# would happily iterate through `for chunk in self` and put onto the queue without
# any blocking, and therefore no yielding, which would prevent other coroutines
# from running.
self._aio_queue = AQueue(maxsize=1)
self._force_noblock = True

# the sole purpose of this coroutine is to connect our pipe_queue (which is
# being populated by a thread) to an asyncio-friendly queue. then, in __anext__,
# we can iterate over that asyncio queue.
async def queue_connector():
for chunk in self:
if chunk == errno.EWOULDBLOCK:
pass
else:
await self._aio_queue.put(chunk)
await self._aio_queue.put(None)

if sys.version_info < (3, 7, 0):
asyncio.ensure_future(queue_connector())
else:
asyncio.create_task(queue_connector())

return self

async def __anext__(self):
chunk = await self._aio_queue.get()
if chunk is not None:
return chunk
else:
raise StopAsyncIteration

def __exit__(self, exc_type, exc_val, exc_tb):
if self.call_args["with"] and get_prepend_stack():
@@ -1138,6 +1180,7 @@ class Command(object):
"piped": None,
"iter": None,
"iter_noblock": None,
"async": False,
# the amount of time to sleep between polling for the iter output queue
"iter_poll_time": 0.1,
"ok_code": 0,
@@ -2371,7 +2414,7 @@ def is_alive(self):
# thread, which is attempting to call wait(). by introducing a tiny sleep
# (ugh), this seems to prevent other threads from equally attempting to
# acquire the lock. TODO find out if this is a general python bug
- time.sleep(0.00001)
+ time.sleep(0.1)
if self.exit_code is not None:
return False, self.exit_code
return True, self.exit_code
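Taken together, the sh.py changes let a command launched with `_iter=True` be consumed with `async for`, which the new test below exercises. A hedged usage sketch (the script name is illustrative only):

```python
import asyncio
import sh

async def read_lines():
    # __aiter__ flips the command into no-block mode and bridges its output
    # queue to asyncio, so iterating here does not stall the event loop.
    async for line in sh.python("produce_lines.py", _iter=True):
        print(line.strip())

asyncio.get_event_loop().run_until_complete(read_lines())
```
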
40 changes: 40 additions & 0 deletions tests/test.py
@@ -10,6 +10,9 @@
import platform
import pty
import resource

import asyncio

import sh
import signal
import stat
@@ -1670,6 +1673,43 @@ def test_iter_generator(self):
self.assertEqual(len(out), 42)
self.assertEqual(sum(out), 861)

def test_async_iter(self):
py = create_tmp_test(
"""
import os
import time
for i in range(5):
print(i)
"""
)
from asyncio.queues import Queue as AQueue

q = AQueue()

# this list will prove that our coroutines are yielding to each other as each
# line is produced
alternating = []

async def producer(q):
async for line in python(py.name, _iter=True):
alternating.append(1)
await q.put(int(line.strip()))

await q.put(None)

async def consumer(q):
while True:
line = await q.get()
if line is None:
return
alternating.append(2)

loop = asyncio.get_event_loop()
res = asyncio.gather(producer(q), consumer(q))
loop.run_until_complete(res)
self.assertListEqual(alternating, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2])

def test_handle_both_out_and_err(self):
py = create_tmp_test(
"""
