diff --git a/asyncio/run.py b/asyncio/run.py index 44ae4ab3..6c0601cb 100644 --- a/asyncio/run.py +++ b/asyncio/run.py @@ -2,6 +2,7 @@ __all__ = ['run', 'forever'] +import inspect import threading from . import coroutines @@ -67,8 +68,8 @@ async def main(): if not isinstance(threading.current_thread(), threading._MainThread): raise RuntimeError( "asyncio.run() must be called from the main thread") - if not coroutines.iscoroutine(coro): - raise ValueError("a coroutine was expected, got {!r}".format(coro)) + # if not coroutines.iscoroutine(coro): + # raise ValueError("a coroutine was expected, got {!r}".format(coro)) loop = events.new_event_loop() try: @@ -77,15 +78,26 @@ async def main(): if debug: loop.set_debug(True) - task = loop.create_task(coro) - task.add_done_callback(lambda task: loop.stop()) + if inspect.isasyncgen(coro): + result = None + loop.run_until_complete(coro.asend(None)) + try: + loop.run_forever() + except BaseException as ex: + try: + loop.run_until_complete(coro.athrow(ex)) + except StopAsyncIteration as ex: + if ex.args: + result = ex.args[0] + else: + try: + loop.run_until_complete(coro.asend(None)) + except StopAsyncIteration as ex: + if ex.args: + result = ex.args[0] - try: - loop.run_forever() - except BaseException as ex: - result = loop.run_until_complete(task) else: - result = task.result() + result = loop.run_until_complete(coro) try: # `shutdown_asyncgens` was added in Python 3.6; not all