Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for wait_readable() and wait_writable() on ProactorEventLoop #831

Merged
merged 38 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
a7335d9
Use ThreadSelectorEventLoop on Windows with ProactorEventLoop
davidbrochart Nov 9, 2024
d67a150
Removed AddThreadSelectorEventLoop
davidbrochart Nov 11, 2024
c1dd759
Skip test on Windows/Trio
davidbrochart Nov 11, 2024
76d23fa
Add back loop close and test on Windows/Trio
davidbrochart Nov 12, 2024
8b18582
Review
davidbrochart Nov 12, 2024
40c7347
Merge branch 'master' into selector-thread
agronholm Nov 16, 2024
1051fcd
Fix
davidbrochart Nov 17, 2024
01ccdf2
Removed unneeded anyio_backend_name
davidbrochart Nov 17, 2024
c381617
Added closing parenthesis
davidbrochart Nov 17, 2024
e21f323
Updated changelog
davidbrochart Nov 17, 2024
a379ccf
Use HasFileno from typeshed
davidbrochart Nov 18, 2024
872329a
Merge branch 'master' into selector-thread
davidbrochart Nov 18, 2024
6980062
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2024
19536de
Merge branch 'master' into selector-thread
davidbrochart Nov 25, 2024
3ea42bf
Close selector thread when root task is done
davidbrochart Nov 25, 2024
90b8eea
Merge branch 'master' into selector-thread
agronholm Nov 30, 2024
f68df73
Alternate implementation of the selector thread
agronholm Dec 1, 2024
0a94755
Fixed AttributeError -> KeyError in get_selector()
agronholm Dec 1, 2024
8fce306
Merge branch 'selector-thread' into selector-thread-alternate
agronholm Dec 1, 2024
1d685e4
Fixed AssertionError at exit
agronholm Dec 1, 2024
17140f3
Use FileDescriptorLike also in the test module
agronholm Dec 1, 2024
3fc63a4
Fixed timeout errors
agronholm Dec 1, 2024
84029d2
Fixed mypy errors
agronholm Dec 1, 2024
84f05ce
Fixed linting error
agronholm Dec 1, 2024
56615c1
Updated the changelog and the docs on wait_readable/wait_writable
agronholm Dec 1, 2024
65662c8
Refactored implementation to use a global selector for all event loops
agronholm Dec 1, 2024
d9ac670
Fixed test failures
agronholm Dec 1, 2024
24fdf30
Added explicit thread name
agronholm Dec 1, 2024
6826696
Really set the global selector
agronholm Dec 1, 2024
d793aa5
Remove fd from selector on exception
agronholm Dec 1, 2024
24b4a9b
Fixed events never getting removed from _(read|write)_events
agronholm Dec 2, 2024
7274d34
Addressed some review comments
agronholm Dec 2, 2024
e138761
Drain the buffer from the notify receive socket whenever it's flagged…
agronholm Dec 2, 2024
748270b
Moved remove_(reader|writer) to the else block
agronholm Dec 2, 2024
7c332d6
Don't skip the wait_socket tests on ProactorEventLoop
agronholm Dec 2, 2024
e19938c
Loop until all data is read
agronholm Dec 3, 2024
dce49fc
Added an implementation note
agronholm Dec 3, 2024
43d1dd4
Fixed wording of new paragraph
agronholm Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
an object with a ``.fileno()`` method or an integer handle, and deprecated
their now obsolete versions (``wait_socket_readable()`` and
``wait_socket_writable()`` (PR by @davidbrochart)
``wait_socket_writable()``) (PR by @davidbrochart)
- Added support for ``wait_readable()`` and ``wait_writable()`` on ``ProactorEventLoop``
(used on asyncio + Windows by default)
- Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the
``anyio.AsyncFile`` class
(`#825 <https://github.com/agronholm/anyio/issues/825>`_)
Expand Down
60 changes: 35 additions & 25 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike
else:
FileDescriptorLike = object

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -2734,7 +2736,7 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)

@classmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
Expand All @@ -2746,25 +2748,29 @@ async def wait_readable(cls, obj: HasFileno | int) -> None:
obj = obj.fileno()

if read_events.get(obj):
raise BusyResourceError("reading from") from None
raise BusyResourceError("reading from")

loop = get_running_loop()
event = read_events[obj] = asyncio.Event()
loop.add_reader(obj, event.set)
event = asyncio.Event()
try:
loop.add_reader(obj, event.set)
remove_reader = loop.remove_reader
graingert marked this conversation as resolved.
Show resolved Hide resolved
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.add_reader(obj, event.set)
remove_reader = selector.remove_reader

read_events[obj] = event
try:
await event.wait()
finally:
if read_events.pop(obj, None) is not None:
loop.remove_reader(obj)
readable = True
else:
readable = False

if not readable:
raise ClosedResourceError
remove_reader(obj)
del read_events[obj]

@classmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
Expand All @@ -2776,22 +2782,26 @@ async def wait_writable(cls, obj: HasFileno | int) -> None:
obj = obj.fileno()

if write_events.get(obj):
raise BusyResourceError("writing to") from None
raise BusyResourceError("writing to")

loop = get_running_loop()
event = write_events[obj] = asyncio.Event()
loop.add_writer(obj, event.set)
event = asyncio.Event()
try:
loop.add_writer(obj, event.set)
remove_writer = loop.remove_writer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think this should be in the except's else clause

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought someone might pick up on that...

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.add_writer(obj, event.set)
remove_writer = selector.remove_writer

write_events[obj] = event
try:
await event.wait()
finally:
if write_events.pop(obj, None) is not None:
loop.remove_writer(obj)
writable = True
else:
writable = False

if not writable:
raise ClosedResourceError
del write_events[obj]
remove_writer(obj)

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
Expand Down
150 changes: 150 additions & 0 deletions src/anyio/_core/_asyncio_selector_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import annotations

import asyncio
import socket
import threading
from collections.abc import Callable
from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from _typeshed import FileDescriptorLike

_selector_lock = threading.Lock()
_selector: Selector | None = None


class Selector:
def __init__(self) -> None:
self._thread = threading.Thread(target=self.run, name="AnyIO socket selector")
self._selector = DefaultSelector()
self._send, self._receive = socket.socketpair()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think both sockets should be non-blocking, and you should ignore BlockingIOError on the send side

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

self._send.setblocking(False)
self._receive.setblocking(False)
self._selector.register(self._receive, EVENT_READ)
self._closed = False

def start(self) -> None:
self._thread.start()
threading._register_atexit(self._stop) # type: ignore[attr-defined]

def _stop(self) -> None:
global _selector
self._closed = True
self._notify_self()
self._send.close()
self._thread.join()
self._selector.unregister(self._receive)
self._receive.close()
self._selector.close()
_selector = None
assert (
not self._selector.get_map()
), "selector still has registered file descriptors after shutdown"

def _notify_self(self) -> None:
try:
self._send.send(b"\x00")
except BlockingIOError:
pass

def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)})
else:
if EVENT_READ in key.data:
raise ValueError(
"this file descriptor is already registered for reading"
)

key.data[EVENT_READ] = loop, callback
self._selector.modify(fd, key.events | EVENT_READ, key.data)

self._notify_self()

def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)})
else:
if EVENT_WRITE in key.data:
raise ValueError(
"this file descriptor is already registered for writing"
)

key.data[EVENT_WRITE] = loop, callback
self._selector.modify(fd, key.events | EVENT_WRITE, key.data)

self._notify_self()

def remove_reader(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False

if new_events := key.events ^ EVENT_READ:
del key.data[EVENT_READ]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)

return True

def remove_writer(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False

if new_events := key.events ^ EVENT_WRITE:
del key.data[EVENT_WRITE]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)

return True

def run(self) -> None:
while not self._closed:
for key, events in self._selector.select():
if key.fileobj is self._receive:
while True:
try:
self._receive.recv(4096)
graingert marked this conversation as resolved.
Show resolved Hide resolved
except BlockingIOError:
break

continue

if events & EVENT_READ:
loop, callback = key.data[EVENT_READ]
self.remove_reader(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed

if events & EVENT_WRITE:
loop, callback = key.data[EVENT_WRITE]
self.remove_writer(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed


def get_selector() -> Selector:
global _selector

with _selector_lock:
if _selector is None:
_selector = Selector()
_selector.start()

return _selector
31 changes: 11 additions & 20 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from ._tasks import create_task_group, move_on_after

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike
else:
HasFileno = object
FileDescriptorLike = object

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
Expand Down Expand Up @@ -609,9 +609,6 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:

Wait until the given socket has data to be read.

This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).

.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!

Expand Down Expand Up @@ -649,7 +646,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
return get_async_backend().wait_writable(sock.fileno())


def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
def wait_readable(obj: FileDescriptorLike) -> Awaitable[None]:
"""
Wait until the given object has data to be read.

Expand All @@ -663,10 +660,7 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
descriptors aren't supported, and neither are handles that refer to anything besides
a ``SOCKET``.

This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).

.. warning:: Only use this on raw sockets that have not been wrapped by any higher
.. warning:: Don't use this on raw sockets that have been wrapped by any higher
level constructs like socket streams!

:param obj: an object with a ``.fileno()`` method or an integer handle
Expand All @@ -679,25 +673,22 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
return get_async_backend().wait_readable(obj)


def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
"""
Wait until the given object can be written to.

This does **NOT** work on Windows when using the asyncio backend with a proactor
graingert marked this conversation as resolved.
Show resolved Hide resolved
event loop (default on py3.8+).

.. seealso:: See the documentation of :func:`wait_readable` for the definition of
``obj``.

.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!

:param obj: an object with a ``.fileno()`` method or an integer handle
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become writable

.. seealso:: See the documentation of :func:`wait_readable` for the definition of
``obj``.

.. warning:: Don't use this on raw sockets that have been wrapped by any higher
level constructs like socket streams!

"""
return get_async_backend().wait_writable(obj)

Expand Down
19 changes: 6 additions & 13 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike

AnyIPAddressFamily = Literal[
AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
Expand Down Expand Up @@ -1858,16 +1858,7 @@ async def test_connect_tcp_getaddrinfo_context() -> None:

@pytest.mark.parametrize("socket_type", ["socket", "fd"])
@pytest.mark.parametrize("event", ["readable", "writable"])
async def test_wait_socket(
anyio_backend_name: str, event: str, socket_type: str
) -> None:
if anyio_backend_name == "asyncio" and platform.system() == "Windows":
import asyncio

policy = asyncio.get_event_loop_policy()
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")

async def test_wait_socket(event: str, socket_type: str) -> None:
wait = wait_readable if event == "readable" else wait_writable

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
Expand All @@ -1880,8 +1871,10 @@ async def test_wait_socket(

conn, addr = server_sock.accept()
with conn:
sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn
with fail_after(10):
sock_or_fd: FileDescriptorLike = (
conn.fileno() if socket_type == "fd" else conn
)
with fail_after(3):
await wait(sock_or_fd)
assert conn.recv(1024) == b"Hello, world"

Expand Down
Loading