Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[serve] Refactor UserCallableWrapper to return concurrent.futures.Future" #48468

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 54 additions & 66 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import concurrent.futures
import inspect
import logging
import os
Expand Down Expand Up @@ -467,13 +466,11 @@ async def _call_user_generator(
async def _enqueue_thread_safe(item: Any):
self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item)

call_user_method_future = asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata,
request_args,
request_kwargs,
generator_result_callback=_enqueue_thread_safe,
)
call_user_method_future = self._user_callable_wrapper.call_user_method(
request_metadata,
request_args,
request_kwargs,
generator_result_callback=_enqueue_thread_safe,
)

while True:
Expand Down Expand Up @@ -524,10 +521,8 @@ async def handle_request(
"""Entrypoint for `stream=False` calls."""
request_metadata = pickle.loads(pickled_request_metadata)
with self._wrap_user_method_call(request_metadata, request_args):
return await asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)
return await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)

async def handle_request_streaming(
Expand Down Expand Up @@ -598,10 +593,8 @@ async def handle_request_with_rejection(
):
yield result
else:
yield await asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)
yield await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)

async def handle_request_from_java(
Expand All @@ -624,10 +617,8 @@ async def handle_request_from_java(
route=proto.route,
)
with self._wrap_user_method_call(request_metadata, request_args):
return await asyncio.wrap_future(
self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)
return await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)

async def is_allocated(self) -> str:
Expand Down Expand Up @@ -672,18 +663,16 @@ async def initialize_and_get_metadata(
async with self._user_callable_initialized_lock:
initialization_start_time = time.time()
if not self._user_callable_initialized:
self._user_callable_asgi_app = await asyncio.wrap_future(
self._user_callable_wrapper.initialize_callable()
self._user_callable_asgi_app = (
await self._user_callable_wrapper.initialize_callable()
)
self._user_callable_initialized = True
self._set_internal_replica_context(
servable_object=self._user_callable_wrapper.user_callable
)
if deployment_config:
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
)
await self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
)

# A new replica should not be considered healthy until it passes
Expand Down Expand Up @@ -724,10 +713,8 @@ async def reconfigure(
self._configure_logger_and_profilers(deployment_config.logging_config)

if user_config_changed:
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
)
await self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
)

# We need to update internal replica context to reflect the new
Expand Down Expand Up @@ -802,7 +789,7 @@ async def perform_graceful_shutdown(self):
await self._drain_ongoing_requests()

try:
await asyncio.wrap_future(self._user_callable_wrapper.call_destructor())
await self._user_callable_wrapper.call_destructor()
except: # noqa: E722
# We catch a blanket exception since the constructor may still be
# running, so instance variables used by the destructor may not exist.
Expand All @@ -817,18 +804,20 @@ async def perform_graceful_shutdown(self):
await self._metrics_manager.shutdown()

async def check_health(self):
# If there's no user-defined health check, nothing runs on the user code event
# loop and no future is returned.
f: Optional[
concurrent.futures.Future
] = self._user_callable_wrapper.call_user_health_check()
if f is not None:
await asyncio.wrap_future(f)
await self._user_callable_wrapper.call_user_health_check()


class UserCallableWrapper:
"""Wraps a user-provided callable that is used to handle requests to a replica."""

# All interactions with user code run on this loop to avoid blocking the replica's
# main event loop.
# NOTE(edoakes): this is a class variable rather than an instance variable to
# enable writing the `_run_on_user_code_event_loop` decorator method (the decorator
# doesn't have access to `self` at class definition time).
_user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
_user_code_event_loop_thread: Optional[threading.Thread] = None

def __init__(
self,
deployment_def: Callable,
Expand All @@ -853,37 +842,38 @@ def __init__(
# Will be populated in `initialize_callable`.
self._callable = None

# All interactions with user code run on this loop to avoid blocking the
# replica's main event loop.
self._user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
# Start the `_user_code_event_loop_thread` singleton if needed.
if self._user_code_event_loop_thread is None:

def _run_user_code_event_loop():
# Required so that calls to get the current running event loop work
# properly in user code.
asyncio.set_event_loop(self._user_code_event_loop)
self._user_code_event_loop.run_forever()
def _run_user_code_event_loop():
# Required so that calls to get the current running event loop work
# properly in user code.
asyncio.set_event_loop(self._user_code_event_loop)
self._user_code_event_loop.run_forever()

self._user_code_event_loop_thread = threading.Thread(
daemon=True,
target=_run_user_code_event_loop,
)
self._user_code_event_loop_thread.start()
self._user_code_event_loop_thread = threading.Thread(
daemon=True,
target=_run_user_code_event_loop,
)
self._user_code_event_loop_thread.start()

def _run_on_user_code_event_loop(f: Callable) -> Callable:
def _run_on_user_code_event_loop(f: Callable):
"""Decorator to run a coroutine method on the user code event loop.

The method will be modified to be a sync function that returns a
`concurrent.futures.Future`.
The method will be modified to be a sync function that returns an
`asyncio.Future`.
"""
assert inspect.iscoroutinefunction(
f
), "_run_on_user_code_event_loop can only be used on coroutine functions."

@wraps(f)
def wrapper(self, *args, **kwargs) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(
f(self, *args, **kwargs),
self._user_code_event_loop,
def wrapper(*args, **kwargs) -> asyncio.Future:
return asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(
f(*args, **kwargs),
UserCallableWrapper._user_code_event_loop,
)
)

return wrapper
Expand Down Expand Up @@ -995,13 +985,17 @@ async def initialize_callable(self) -> Optional[ASGIApp]:
else None
)

@_run_on_user_code_event_loop
async def _call_user_health_check(self):
await self._call_func_or_gen(self._user_health_check)

def _raise_if_not_initialized(self, method_name: str):
    """Raise if the user callable has not been set up yet.

    Args:
        method_name: Name of the calling method; interpolated into the
            error message so the failure points at the offending call.

    Raises:
        RuntimeError: If `self._callable` is still `None`.
    """
    if self._callable is None:
        # This must be an f-string: a plain literal would emit the
        # placeholder text `{method_name}` verbatim instead of the name.
        raise RuntimeError(
            f"`initialize_callable` must be called before `{method_name}`."
        )

def call_user_health_check(self) -> Optional[concurrent.futures.Future]:
async def call_user_health_check(self):
self._raise_if_not_initialized("call_user_health_check")

# If the user provided a health check, call it on the user code thread. If user
Expand All @@ -1010,13 +1004,7 @@ def call_user_health_check(self) -> Optional[concurrent.futures.Future]:
# To avoid this issue for basic cases without a user-defined health check, skip
# interacting with the user callable entirely.
if self._user_health_check is not None:
return self._call_user_health_check()

return None

@_run_on_user_code_event_loop
async def _call_user_health_check(self):
await self._call_func_or_gen(self._user_health_check)
return await self._call_user_health_check()

@_run_on_user_code_event_loop
async def call_reconfigure(self, user_config: Any):
Expand Down
Loading