Skip to content

Commit

Permalink
Async functions and async generator functions with the every option…
Browse files Browse the repository at this point in the history
… to work (#6395)

* Extend `get_continuous_fn()` to deal with async functions and async generator functions

* add changeset

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
whitphx and gradio-pr-bot authored Nov 13, 2023
1 parent 03491ef commit 8ef48f8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
5 changes: 5 additions & 0 deletions .changeset/cold-gifts-tickle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Async functions and async generator functions with the `every` option to work
7 changes: 6 additions & 1 deletion gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from io import BytesIO
from numbers import Number
from pathlib import Path
from types import GeneratorType
from types import AsyncGeneratorType, GeneratorType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -604,6 +604,11 @@ async def continuous_coro(*args):
if isinstance(output, GeneratorType):
for item in output:
yield item
elif isinstance(output, AsyncGeneratorType):
async for item in output:
yield item
elif inspect.isawaitable(output):
yield await output
else:
yield output
await asyncio.sleep(every)
Expand Down
33 changes: 33 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,39 @@ def list_yield(x): # new condition
assert [1, 1] == await agener_list.__anext__()
assert [1, 1, 1] == await agener_list.__anext__()

@pytest.mark.asyncio
async def test_get_continuous_fn_with_async_function(self):
async def async_int_return(x): # for origin condition
return x + 1

agen_int_return = get_continuous_fn(fn=async_int_return, every=0.01)
agener_int_return = agen_int_return(1)
assert await agener_int_return.__anext__() == 2
assert await agener_int_return.__anext__() == 2

@pytest.mark.asyncio
async def test_get_continuous_fn_with_async_generator(self):
async def async_int_yield(x): # new condition
for _i in range(2):
yield x
x += 1

async def async_list_yield(x): # new condition
for _i in range(2):
yield x
x += [1]

agen_int_yield = get_continuous_fn(fn=async_int_yield, every=0.01)
agen_list_yield = get_continuous_fn(fn=async_list_yield, every=0.01)
agener_int = agen_int_yield(1) # Primitive
agener_list = agen_list_yield([1]) # Reference
assert await agener_int.__anext__() == 1
assert await agener_int.__anext__() == 2
assert await agener_int.__anext__() == 1
assert [1] == await agener_list.__anext__()
assert [1, 1] == await agener_list.__anext__()
assert [1, 1, 1] == await agener_list.__anext__()


def test_tex2svg_preserves_matplotlib_backend():
import matplotlib
Expand Down

0 comments on commit 8ef48f8

Please sign in to comment.