Skip to content

Commit

Permalink
[serve] Refactor UserCallableWrapper to return `concurrent.futures.…
Browse files Browse the repository at this point in the history
…Future` (#48449)

Previously, `UserCallableWrapper` wrapped results in an
`asyncio.Future`. This was convenient for use in the replica, but this
code will be re-used from a sync context for local testing mode, so I
need access to the underlying `concurrent.futures.Future` instead.

Alternatively, we could add a duplicate path that returns the concurrent
future instead, but I prefer this approach because it makes it explicit
that the code is running on a separate thread (which will help guard
against bugs in the future).

I've also moved `_user_code_event_loop_thread` and the event loop
(`_user_code_event_loop`) into instance attributes instead of class
attributes. I must have missed something when implementing this
previously because this was basically trivial.

---------

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
edoakes authored Oct 30, 2024
1 parent 41be27c commit 75d652c
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 156 deletions.
120 changes: 66 additions & 54 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import concurrent.futures
import inspect
import logging
import os
Expand Down Expand Up @@ -466,11 +467,13 @@ async def _call_user_generator(
async def _enqueue_thread_safe(item: Any):
self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item)

call_user_method_future = self._user_callable_wrapper.call_user_method(
request_metadata,
request_args,
request_kwargs,
generator_result_callback=_enqueue_thread_safe,
call_user_method_future = asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata,
request_args,
request_kwargs,
generator_result_callback=_enqueue_thread_safe,
)
)

while True:
Expand Down Expand Up @@ -521,8 +524,10 @@ async def handle_request(
"""Entrypoint for `stream=False` calls."""
request_metadata = pickle.loads(pickled_request_metadata)
with self._wrap_user_method_call(request_metadata, request_args):
return await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
return await asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)
)

async def handle_request_streaming(
Expand Down Expand Up @@ -593,8 +598,10 @@ async def handle_request_with_rejection(
):
yield result
else:
yield await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
yield await asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)
)

async def handle_request_from_java(
Expand All @@ -617,8 +624,10 @@ async def handle_request_from_java(
route=proto.route,
)
with self._wrap_user_method_call(request_metadata, request_args):
return await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
return await asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)
)

async def is_allocated(self) -> str:
Expand Down Expand Up @@ -663,16 +672,18 @@ async def initialize_and_get_metadata(
async with self._user_callable_initialized_lock:
initialization_start_time = time.time()
if not self._user_callable_initialized:
self._user_callable_asgi_app = (
await self._user_callable_wrapper.initialize_callable()
self._user_callable_asgi_app = await asyncio.wrap_future(
self._user_callable_wrapper.initialize_callable()
)
self._user_callable_initialized = True
self._set_internal_replica_context(
servable_object=self._user_callable_wrapper.user_callable
)
if deployment_config:
await self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
)
)

# A new replica should not be considered healthy until it passes
Expand Down Expand Up @@ -713,8 +724,10 @@ async def reconfigure(
self._configure_logger_and_profilers(deployment_config.logging_config)

if user_config_changed:
await self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
)
)

# We need to update internal replica context to reflect the new
Expand Down Expand Up @@ -789,7 +802,7 @@ async def perform_graceful_shutdown(self):
await self._drain_ongoing_requests()

try:
await self._user_callable_wrapper.call_destructor()
await asyncio.wrap_future(self._user_callable_wrapper.call_destructor())
except: # noqa: E722
# We catch a blanket exception since the constructor may still be
# running, so instance variables used by the destructor may not exist.
Expand All @@ -804,20 +817,18 @@ async def perform_graceful_shutdown(self):
await self._metrics_manager.shutdown()

async def check_health(self):
await self._user_callable_wrapper.call_user_health_check()
# If there's no user-defined health check, nothing runs on the user code event
# loop and no future is returned.
f: Optional[
concurrent.futures.Future
] = self._user_callable_wrapper.call_user_health_check()
if f is not None:
await asyncio.wrap_future(f)


class UserCallableWrapper:
"""Wraps a user-provided callable that is used to handle requests to a replica."""

# All interactions with user code run on this loop to avoid blocking the replica's
# main event loop.
# NOTE(edoakes): this is a class variable rather than an instance variable to
# enable writing the `_run_on_user_code_event_loop` decorator method (the decorator
# doesn't have access to `self` at class definition time).
_user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
_user_code_event_loop_thread: Optional[threading.Thread] = None

def __init__(
self,
deployment_def: Callable,
Expand All @@ -842,38 +853,37 @@ def __init__(
# Will be populated in `initialize_callable`.
self._callable = None

# Start the `_user_code_event_loop_thread` singleton if needed.
if self._user_code_event_loop_thread is None:
# All interactions with user code run on this loop to avoid blocking the
# replica's main event loop.
self._user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()

def _run_user_code_event_loop():
# Required so that calls to get the current running event loop work
# properly in user code.
asyncio.set_event_loop(self._user_code_event_loop)
self._user_code_event_loop.run_forever()
def _run_user_code_event_loop():
# Required so that calls to get the current running event loop work
# properly in user code.
asyncio.set_event_loop(self._user_code_event_loop)
self._user_code_event_loop.run_forever()

self._user_code_event_loop_thread = threading.Thread(
daemon=True,
target=_run_user_code_event_loop,
)
self._user_code_event_loop_thread.start()
self._user_code_event_loop_thread = threading.Thread(
daemon=True,
target=_run_user_code_event_loop,
)
self._user_code_event_loop_thread.start()

def _run_on_user_code_event_loop(f: Callable):
def _run_on_user_code_event_loop(f: Callable) -> Callable:
"""Decorator to run a coroutine method on the user code event loop.
The method will be modified to be a sync function that returns an
`asyncio.Future`.
The method will be modified to be a sync function that returns a
`concurrent.futures.Future`.
"""
assert inspect.iscoroutinefunction(
f
), "_run_on_user_code_event_loop can only be used on coroutine functions."

@wraps(f)
def wrapper(*args, **kwargs) -> asyncio.Future:
return asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(
f(*args, **kwargs),
UserCallableWrapper._user_code_event_loop,
)
def wrapper(self, *args, **kwargs) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(
f(self, *args, **kwargs),
self._user_code_event_loop,
)

return wrapper
Expand Down Expand Up @@ -985,17 +995,13 @@ async def initialize_callable(self) -> Optional[ASGIApp]:
else None
)

@_run_on_user_code_event_loop
async def _call_user_health_check(self):
await self._call_func_or_gen(self._user_health_check)

def _raise_if_not_initialized(self, method_name: str):
if self._callable is None:
raise RuntimeError(
"`initialize_callable` must be called before `{method_name}`."
)

async def call_user_health_check(self):
def call_user_health_check(self) -> Optional[concurrent.futures.Future]:
self._raise_if_not_initialized("call_user_health_check")

# If the user provided a health check, call it on the user code thread. If user
Expand All @@ -1004,7 +1010,13 @@ async def call_user_health_check(self):
# To avoid this issue for basic cases without a user-defined health check, skip
# interacting with the user callable entirely.
if self._user_health_check is not None:
return await self._call_user_health_check()
return self._call_user_health_check()

return None

@_run_on_user_code_event_loop
async def _call_user_health_check(self):
await self._call_func_or_gen(self._user_health_check)

@_run_on_user_code_event_loop
async def call_reconfigure(self, user_config: Any):
Expand Down
Loading

0 comments on commit 75d652c

Please sign in to comment.