From 1088c925bff8b18f4d5ae6d6089fb0f62335f468 Mon Sep 17 00:00:00 2001 From: Cuong Nguyen <128072568+can-anyscale@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:22:07 -0700 Subject: [PATCH] =?UTF-8?q?Revert=20"[serve]=20Refactor=20`UserCallableWra?= =?UTF-8?q?pper`=20to=20return=20`concurrent.futures.=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 75d652ccb85b61c391458123dc9b42737400a154. --- python/ray/serve/_private/replica.py | 120 +++++----- .../tests/unit/test_user_callable_wrapper.py | 213 +++++++++--------- 2 files changed, 156 insertions(+), 177 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index c0bbfdf9d9bd9..5717182b15ab3 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -1,5 +1,4 @@ import asyncio -import concurrent.futures import inspect import logging import os @@ -467,13 +466,11 @@ async def _call_user_generator( async def _enqueue_thread_safe(item: Any): self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item) - call_user_method_future = asyncio.wrap_future( - self._user_callable_wrapper.call_user_method( - request_metadata, - request_args, - request_kwargs, - generator_result_callback=_enqueue_thread_safe, - ) + call_user_method_future = self._user_callable_wrapper.call_user_method( + request_metadata, + request_args, + request_kwargs, + generator_result_callback=_enqueue_thread_safe, ) while True: @@ -524,10 +521,8 @@ async def handle_request( """Entrypoint for `stream=False` calls.""" request_metadata = pickle.loads(pickled_request_metadata) with self._wrap_user_method_call(request_metadata, request_args): - return await asyncio.wrap_future( - self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs - ) + return await self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs ) async def handle_request_streaming( @@ -598,10 +593,8 @@ async def handle_request_with_rejection( ): yield result else: - yield await asyncio.wrap_future( - self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs - ) + yield await self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs ) async def handle_request_from_java( @@ -624,10 +617,8 @@ async def handle_request_from_java( route=proto.route, ) with self._wrap_user_method_call(request_metadata, request_args): - return await asyncio.wrap_future( - self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs - ) + return await self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs ) async def is_allocated(self) -> str: @@ -672,18 +663,16 @@ async def initialize_and_get_metadata( async with self._user_callable_initialized_lock: initialization_start_time = time.time() if not self._user_callable_initialized: - self._user_callable_asgi_app = await asyncio.wrap_future( - self._user_callable_wrapper.initialize_callable() + self._user_callable_asgi_app = ( + await self._user_callable_wrapper.initialize_callable() ) self._user_callable_initialized = True self._set_internal_replica_context( servable_object=self._user_callable_wrapper.user_callable ) if deployment_config: - await asyncio.wrap_future( - self._user_callable_wrapper.call_reconfigure( - deployment_config.user_config - ) + await self._user_callable_wrapper.call_reconfigure( + deployment_config.user_config ) # A new replica should not be considered healthy until it passes @@ -724,10 +713,8 @@ async def reconfigure( self._configure_logger_and_profilers(deployment_config.logging_config) if user_config_changed: - await asyncio.wrap_future( - self._user_callable_wrapper.call_reconfigure( - deployment_config.user_config - ) + await self._user_callable_wrapper.call_reconfigure( + deployment_config.user_config ) # We need to update internal replica context to reflect the new @@ -802,7 +789,7 @@ async def perform_graceful_shutdown(self): await self._drain_ongoing_requests() try: - await asyncio.wrap_future(self._user_callable_wrapper.call_destructor()) + await self._user_callable_wrapper.call_destructor() except: # noqa: E722 # We catch a blanket exception since the constructor may still be # running, so instance variables used by the destructor may not exist. @@ -817,18 +804,20 @@ async def perform_graceful_shutdown(self): await self._metrics_manager.shutdown() async def check_health(self): - # If there's no user-defined health check, nothing runs on the user code event - # loop and no future is returned. - f: Optional[ - concurrent.futures.Future - ] = self._user_callable_wrapper.call_user_health_check() - if f is not None: - await asyncio.wrap_future(f) + await self._user_callable_wrapper.call_user_health_check() class UserCallableWrapper: """Wraps a user-provided callable that is used to handle requests to a replica.""" + # All interactions with user code run on this loop to avoid blocking the replica's + # main event loop. + # NOTE(edoakes): this is a class variable rather than an instance variable to + # enable writing the `_run_on_user_code_event_loop` decorator method (the decorator + # doesn't have access to `self` at class definition time). + _user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + _user_code_event_loop_thread: Optional[threading.Thread] = None + def __init__( self, deployment_def: Callable, @@ -853,37 +842,38 @@ def __init__( # Will be populated in `initialize_callable`. self._callable = None - # All interactions with user code run on this loop to avoid blocking the - # replica's main event loop. - self._user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + # Start the `_user_code_event_loop_thread` singleton if needed. + if self._user_code_event_loop_thread is None: - def _run_user_code_event_loop(): - # Required so that calls to get the current running event loop work - # properly in user code. - asyncio.set_event_loop(self._user_code_event_loop) - self._user_code_event_loop.run_forever() + def _run_user_code_event_loop(): + # Required so that calls to get the current running event loop work + # properly in user code. + asyncio.set_event_loop(self._user_code_event_loop) + self._user_code_event_loop.run_forever() - self._user_code_event_loop_thread = threading.Thread( - daemon=True, - target=_run_user_code_event_loop, - ) - self._user_code_event_loop_thread.start() + self._user_code_event_loop_thread = threading.Thread( + daemon=True, + target=_run_user_code_event_loop, + ) + self._user_code_event_loop_thread.start() - def _run_on_user_code_event_loop(f: Callable) -> Callable: + def _run_on_user_code_event_loop(f: Callable): """Decorator to run a coroutine method on the user code event loop. - The method will be modified to be a sync function that returns a - `concurrent.futures.Future`. + The method will be modified to be a sync function that returns an + `asyncio.Future`. """ assert inspect.iscoroutinefunction( f ), "_run_on_user_code_event_loop can only be used on coroutine functions." @wraps(f) - def wrapper(self, *args, **kwargs) -> concurrent.futures.Future: - return asyncio.run_coroutine_threadsafe( - f(self, *args, **kwargs), - self._user_code_event_loop, + def wrapper(*args, **kwargs) -> asyncio.Future: + return asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + f(*args, **kwargs), + UserCallableWrapper._user_code_event_loop, + ) ) return wrapper @@ -995,13 +985,17 @@ async def initialize_callable(self) -> Optional[ASGIApp]: else None ) + @_run_on_user_code_event_loop + async def _call_user_health_check(self): + await self._call_func_or_gen(self._user_health_check) + def _raise_if_not_initialized(self, method_name: str): if self._callable is None: raise RuntimeError( "`initialize_callable` must be called before `{method_name}`." ) - def call_user_health_check(self) -> Optional[concurrent.futures.Future]: + async def call_user_health_check(self): self._raise_if_not_initialized("call_user_health_check") # If the user provided a health check, call it on the user code thread. If user @@ -1010,13 +1004,7 @@ def call_user_health_check(self) -> Optional[concurrent.futures.Future]: # To avoid this issue for basic cases without a user-defined health check, skip # interacting with the user callable entirely. if self._user_health_check is not None: - return self._call_user_health_check() - - return None - - @_run_on_user_code_event_loop - async def _call_user_health_check(self): - await self._call_func_or_gen(self._user_health_check) + return await self._call_user_health_check() @_run_on_user_code_event_loop async def call_reconfigure(self, user_config: Any): diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index d341c62db0162..53f91bda1ebbe 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -1,5 +1,4 @@ import asyncio -import concurrent.futures import pickle import sys import threading @@ -124,102 +123,93 @@ def _make_request_metadata( ) -def test_calling_initialize_twice(): +@pytest.mark.asyncio +async def test_calling_initialize_twice(): user_callable_wrapper = _make_user_callable_wrapper() - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() assert isinstance(user_callable_wrapper.user_callable, BasicClass) with pytest.raises(RuntimeError): - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() -def test_calling_methods_before_initialize(): +@pytest.mark.asyncio +async def test_calling_methods_before_initialize(): user_callable_wrapper = _make_user_callable_wrapper() with pytest.raises(RuntimeError): - user_callable_wrapper.call_user_method(None, tuple(), dict()).result() + await user_callable_wrapper.call_user_method(None, tuple(), dict()) with pytest.raises(RuntimeError): - user_callable_wrapper.call_user_health_check().result() + await user_callable_wrapper.call_user_health_check() with pytest.raises(RuntimeError): - user_callable_wrapper.call_reconfigure(None).result() + await user_callable_wrapper.call_reconfigure(None) -def test_basic_class_callable(): +@pytest.mark.asyncio +async def test_basic_class_callable(): user_callable_wrapper = _make_user_callable_wrapper() - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() # Call non-generator method with is_streaming. request_metadata = _make_request_metadata(is_streaming=True) with pytest.raises(RayTaskError, match="did not return a generator."): - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) # Test calling default sync `__call__` method. request_metadata = _make_request_metadata() assert ( - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) ) == "hi" assert ( - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, ("-arg",), dict() - ).result() - == "hi-arg" - ) + ) + ) == "hi-arg" assert ( - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, tuple(), {"suffix": "-kwarg"} - ).result() - == "hi-kwarg" - ) + ) + ) == "hi-kwarg" with pytest.raises(RayTaskError, match="uh-oh"): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, tuple(), {"raise_exception": True} - ).result() + ) # Call non-generator async method with is_streaming. request_metadata = _make_request_metadata( call_method="call_async", is_streaming=True ) with pytest.raises(RayTaskError, match="did not return a generator."): - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) # Test calling `call_async` method. request_metadata = _make_request_metadata(call_method="call_async") assert ( - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() - == "hi" - ) + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + ) == "hi" assert ( - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, ("-arg",), dict() - ).result() - == "hi-arg" - ) + ) + ) == "hi-arg" assert ( - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, tuple(), {"suffix": "-kwarg"} - ).result() - == "hi-kwarg" - ) + ) + ) == "hi-kwarg" with pytest.raises(RayTaskError, match="uh-oh"): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, tuple(), {"raise_exception": True} - ).result() + ) -def test_basic_class_callable_generators(): +@pytest.mark.asyncio +async def test_basic_class_callable_generators(): user_callable_wrapper = _make_user_callable_wrapper() - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() result_list = [] @@ -233,28 +223,28 @@ async def append_to_list(item: Any): with pytest.raises( RayTaskError, match="Method 'call_generator' returned a generator." ): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ).result() + ) # Call sync generator. request_metadata = _make_request_metadata( call_method="call_generator", is_streaming=True ) - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ).result() + ) assert result_list == list(range(10)) result_list.clear() # Call sync generator raising exception. with pytest.raises(RayTaskError, match="uh-oh"): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), {"raise_exception": True}, generator_result_callback=append_to_list, - ).result() + ) assert result_list == [0] result_list.clear() @@ -265,69 +255,67 @@ async def append_to_list(item: Any): with pytest.raises( RayTaskError, match="Method 'call_async_generator' returned a generator." ): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ).result() + ) # Call async generator. request_metadata = _make_request_metadata( call_method="call_async_generator", is_streaming=True ) - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ).result() + ) assert result_list == list(range(10)) result_list.clear() # Call async generator raising exception. with pytest.raises(RayTaskError, match="uh-oh"): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), {"raise_exception": True}, generator_result_callback=append_to_list, - ).result() + ) assert result_list == [0] +@pytest.mark.asyncio @pytest.mark.parametrize("fn", [basic_sync_function, basic_async_function]) -def test_basic_function_callable(fn: Callable): +async def test_basic_function_callable(fn: Callable): user_callable_wrapper = _make_user_callable_wrapper(fn) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() # Call non-generator function with is_streaming. request_metadata = _make_request_metadata(is_streaming=True) with pytest.raises(RayTaskError, match="did not return a generator."): - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) request_metadata = _make_request_metadata() assert ( - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) ) == "hi" assert ( - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, ("-arg",), dict() - ).result() + ) ) == "hi-arg" assert ( - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, tuple(), {"suffix": "-kwarg"} - ).result() + ) ) == "hi-kwarg" with pytest.raises(RayTaskError, match="uh-oh"): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, tuple(), {"raise_exception": True} - ).result() + ) +@pytest.mark.asyncio @pytest.mark.parametrize("fn", [basic_sync_generator, basic_async_generator]) -def test_basic_function_callable_generators(fn: Callable): +async def test_basic_function_callable_generators(fn: Callable): user_callable_wrapper = _make_user_callable_wrapper(fn) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() result_list = [] @@ -339,28 +327,28 @@ async def append_to_list(item: Any): with pytest.raises( RayTaskError, match=f"Method '{fn.__name__}' returned a generator." ): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ).result() + ) # Call generator function. request_metadata = _make_request_metadata( call_method="call_generator", is_streaming=True ) - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ).result() + ) assert result_list == list(range(10)) result_list.clear() # Call generator function raising exception. with pytest.raises(RayTaskError, match="uh-oh"): - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (10,), {"raise_exception": True}, generator_result_callback=append_to_list, - ).result() + ) assert result_list == [0] @@ -387,19 +375,20 @@ def __call__(self) -> asyncio.AbstractEventLoop: return user_method_loop user_callable_wrapper = _make_user_callable_wrapper(GetLoop) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() request_metadata = _make_request_metadata() - user_code_loop = user_callable_wrapper.call_user_method( + user_code_loop = await user_callable_wrapper.call_user_method( request_metadata, tuple(), dict() - ).result() + ) assert isinstance(user_code_loop, asyncio.AbstractEventLoop) assert user_code_loop != main_loop # `check_health` method asserts that it runs on the correct loop. - user_callable_wrapper.call_user_health_check().result() + await user_callable_wrapper.call_user_health_check() -def test_callable_with_async_init(): +@pytest.mark.asyncio +async def test_callable_with_async_init(): class AsyncInitializer: async def __init__(self, msg: str): await asyncio.sleep(0.001) @@ -413,17 +402,16 @@ def __call__(self) -> str: AsyncInitializer, msg, ) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() request_metadata = _make_request_metadata() assert ( - user_callable_wrapper.call_user_method( - request_metadata, tuple(), dict() - ).result() + await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) ) == msg +@pytest.mark.asyncio @pytest.mark.parametrize("async_del", [False, True]) -def test_destructor_only_called_once(async_del: bool): +async def test_destructor_only_called_once(async_del: bool): num_destructor_calls = 0 if async_del: @@ -443,12 +431,15 @@ def __del__(self) -> str: user_callable_wrapper = _make_user_callable_wrapper( DestroyerOfNothing, ) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() # Call `call_destructor` many times in parallel; only the first one should actually # run the `__del__` method. - concurrent.futures.wait( - [user_callable_wrapper.call_destructor() for _ in range(100)] + await asyncio.gather( + *[ + asyncio.ensure_future(user_callable_wrapper.call_destructor()) + for _ in range(100) + ] ) assert num_destructor_calls == 1 @@ -470,22 +461,20 @@ async def __call__(self) -> str: user_callable_wrapper = _make_user_callable_wrapper( LoopBlocker, ) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() request_metadata = _make_request_metadata() blocked_future = user_callable_wrapper.call_user_method( request_metadata, tuple(), dict() ) - _, pending = concurrent.futures.wait([blocked_future], timeout=0.01) + _, pending = await asyncio.wait([blocked_future], timeout=0.01) assert len(pending) == 1 for _ in range(100): # If this called something on the event loop, it'd be blocked. - # Instead, `user_callable_wrapper.call_user_health_check` returns None - # when there's no user health check configured. - assert user_callable_wrapper.call_user_health_check() is None + await user_callable_wrapper.call_user_health_check() sync_event.set() - assert blocked_future.result() == "Sorry I got stuck!" + assert await blocked_future == "Sorry I got stuck!" class gRPCClass: @@ -497,18 +486,19 @@ def stream(self, msg: serve_pb2.UserDefinedMessage): yield serve_pb2.UserDefinedResponse(greeting=f"Hello {msg.greeting} {i}!") -def test_grpc_unary_request(): +@pytest.mark.asyncio +async def test_grpc_unary_request(): user_callable_wrapper = _make_user_callable_wrapper(gRPCClass) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() grpc_request = gRPCRequest( pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) ) request_metadata = _make_request_metadata(call_method="greet", is_grpc_request=True) - _, result_bytes = user_callable_wrapper.call_user_method( + _, result_bytes = await user_callable_wrapper.call_user_method( request_metadata, (grpc_request,), dict() - ).result() + ) assert isinstance(result_bytes, bytes) result = serve_pb2.UserDefinedResponse() @@ -517,9 +507,9 @@ def test_grpc_unary_request(): @pytest.mark.asyncio -def test_grpc_streaming_request(): +async def test_grpc_streaming_request(): user_callable_wrapper = _make_user_callable_wrapper(gRPCClass) - user_callable_wrapper.initialize_callable() + await user_callable_wrapper.initialize_callable() grpc_request = gRPCRequest( pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) @@ -533,12 +523,12 @@ async def append_to_list(item: Any): request_metadata = _make_request_metadata( call_method="stream", is_grpc_request=True, is_streaming=True ) - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (grpc_request,), dict(), generator_result_callback=append_to_list, - ).result() + ) assert len(result_list) == 10 for i, (_, result_bytes) in enumerate(result_list): @@ -566,10 +556,11 @@ async def handle_root(self, request: Request) -> str: return PlainTextResponse(f"Hello {msg}!") +@pytest.mark.asyncio @pytest.mark.parametrize("callable", [RawRequestHandler, FastAPIRequestHandler]) -def test_http_handler(callable: Callable, monkeypatch): +async def test_http_handler(callable: Callable, monkeypatch): user_callable_wrapper = _make_user_callable_wrapper(callable) - user_callable_wrapper.initialize_callable().result() + await user_callable_wrapper.initialize_callable() @dataclass class MockReplicaContext: @@ -620,12 +611,12 @@ async def append_to_list(item: Any): result_list.append(item) request_metadata = _make_request_metadata(is_http_request=True, is_streaming=True) - user_callable_wrapper.call_user_method( + await user_callable_wrapper.call_user_method( request_metadata, (http_request,), dict(), generator_result_callback=append_to_list, - ).result() + ) assert result_list[0]["type"] == "http.response.start" assert result_list[0]["status"] == 200