Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support AnyIO #2045

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = ["cffi; implementation_name == 'pypy'"]
dependencies = [
"cffi; implementation_name == 'pypy'",
"anyioutils >=0.4.2"
]
description = "Python bindings for 0MQ"
readme = "README.md"

Expand Down Expand Up @@ -144,7 +147,7 @@ search = '__version__: str = "{current_version}"'
[tool.cibuildwheel]
build-verbosity = "1"
free-threaded-support = true
test-requires = ["pytest>=6", "importlib_metadata"]
test-requires = ["pytest>=6", "importlib_metadata", "exceptiongroup;python_version<'3.11'"]
test-command = "pytest -vsx {package}/tools/test_wheel.py"

[tool.cibuildwheel.linux]
Expand Down
220 changes: 108 additions & 112 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
from multiprocessing import Process

import pytest
from anyio import create_task_group, move_on_after, sleep
from anyioutils import CancelledError, create_task
from pytest import mark

import zmq
import zmq.asyncio as zaio

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup


pytestmark = pytest.mark.anyio


@pytest.fixture
def Context(event_loop):
Expand Down Expand Up @@ -46,23 +54,17 @@ def test_instance_subclass_second(context):
async def test_recv_multipart(context, create_bound_pair):
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b"hi"]
assert await f == [b"hi"]


async def test_recv(create_bound_pair):
a, b = create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b"hi"
assert recvd == b"there"
assert await f1 == b"hi"
assert await f2 == b"there"


@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
Expand All @@ -72,82 +74,70 @@ async def test_recv_timeout(push_pull):
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with pytest.raises(zmq.Again):
with pytest.raises(ExceptionGroup) as excinfo:
await f1
assert excinfo.group_contains(zmq.Again)
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b"hi", b"there"]


@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
async def test_send_timeout(socket):
s = socket(zmq.PUSH)
s.sndtimeo = 100
with pytest.raises(zmq.Again):
with pytest.raises(ExceptionGroup) as excinfo:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully this isn't required when it's wrapped up. A single coroutine should raise a single error, right? This would be a major breaking change and a significant degradation of the API.

await s.send(b"not going anywhere")
assert excinfo.group_contains(zmq.Again)


async def test_recv_string(push_pull):
a, b = push_pull
f = b.recv_string()
assert not f.done()
msg = "πøøπ"
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg


async def test_recv_json(push_pull):
a, b = push_pull
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj


async def test_recv_json_cancelled(push_pull):
a, b = push_pull
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await asyncio.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
async with create_task_group() as tg:
a, b = push_pull
f = create_task(b.recv_json(), tg)
f.cancel(raise_exception=False)
# cycle eventloop to allow cancel events to fire
await sleep(0)
obj = dict(a=5)
await a.send_json(obj)
recvd = await f.wait()
assert f.cancelled()
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
with move_on_after(5):
recvd = await f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == obj
assert recvd == obj


async def test_recv_pyobj(push_pull):
a, b = push_pull
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj


Expand Down Expand Up @@ -206,85 +196,90 @@ async def test_custom_serialize_error(dealer_router):
async def test_recv_dontwait(push_pull):
push, pull = push_pull
f = pull.recv(zmq.DONTWAIT)
with pytest.raises(zmq.Again):
with pytest.raises(BaseExceptionGroup) as excinfo:
await f
assert excinfo.group_contains(zmq.Again)
await push.send(b"ping")
await pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = await f
msg = await pull.recv(zmq.DONTWAIT)
assert msg == b"ping"


async def test_recv_cancel(push_pull):
a, b = push_pull
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]
async with create_task_group() as tg:
a, b = push_pull
f1 = create_task(b.recv(), tg)
f2 = create_task(b.recv_multipart(), tg)
f1.cancel(raise_exception=False)
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2.wait()
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]


async def test_poll(push_pull):
a, b = push_pull
f = b.poll(timeout=0)
await asyncio.sleep(0)
assert f.result() == 0
async with create_task_group() as tg:
a, b = push_pull
f = create_task(b.poll(timeout=0), tg)
await sleep(0.01)
assert f.result() == 0

f = b.poll(timeout=1)
assert not f.done()
evt = await f
f = create_task(b.poll(timeout=1), tg)
assert not f.done()
evt = await f.wait()

assert evt == 0
assert evt == 0

f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]
f = create_task(b.poll(timeout=1000), tg)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f.wait()
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]


async def test_poll_base_socket(sockets):
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)

poller = zaio.Poller()
poller.register(b, zmq.POLLIN)

f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]
async with create_task_group() as tg:
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)

poller = zaio.Poller()
poller.register(b, zmq.POLLIN)

f = create_task(poller.poll(timeout=1000), tg)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f.wait()
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]


async def test_poll_on_closed_socket(push_pull):
a, b = push_pull
with pytest.raises(BaseExceptionGroup) as excinfo:
async with create_task_group() as tg:
a, b = push_pull

f = b.poll(timeout=1)
b.close()
f = create_task(b.poll(timeout=1), tg)
b.close()

# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await sleep(0)
if f.cancelled():
break
assert f.done()
assert excinfo.group_contains(zmq.error.ZMQError)


@pytest.mark.skipif(
Expand Down Expand Up @@ -344,16 +339,17 @@ def test_shadow():


async def test_poll_leak():
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN))
f.cancel()
await asyncio.sleep(0)
# one more sleep allows further chained cleanup
await asyncio.sleep(0.1)
assert len(s._recv_futures) == 0
async with create_task_group() as tg:
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = create_task(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN), tg)
f.cancel(raise_exception=False)
await sleep(0)
# one more sleep allows further chained cleanup
await sleep(0.1)
assert len(s._recv_futures) == 0


class ProcessForTeardownTest(Process):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_ioloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
_tornado = True


def setup():
if not _tornado:
pytest.skip("requires tornado")
if not _tornado:
pytest.skip("requires tornado", allow_module_level=True)


def test_ioloop():
Expand Down
Loading
Loading