From 5deff4ae7597e5c49b4fbf0de5c1fa1bdb3bcf7b Mon Sep 17 00:00:00 2001 From: Cuong Nguyen <128072568+can-anyscale@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:58:20 -0700 Subject: [PATCH] Revert "Revert "[serve] Refactor `UserCallableWrapper` to return `concurrent.futures.Future`"" (#48472) Reverts ray-project/ray#48468 Revert the revert; blame PR was incorrect Signed-off-by: JP-sDEV --- python/ray/serve/_private/replica.py | 120 +++++----- .../tests/unit/test_user_callable_wrapper.py | 213 +++++++++--------- 2 files changed, 177 insertions(+), 156 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 5717182b15ab..c0bbfdf9d9bd 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures import inspect import logging import os @@ -466,11 +467,13 @@ async def _call_user_generator( async def _enqueue_thread_safe(item: Any): self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item) - call_user_method_future = self._user_callable_wrapper.call_user_method( - request_metadata, - request_args, - request_kwargs, - generator_result_callback=_enqueue_thread_safe, + call_user_method_future = asyncio.wrap_future( + self._user_callable_wrapper.call_user_method( + request_metadata, + request_args, + request_kwargs, + generator_result_callback=_enqueue_thread_safe, + ) ) while True: @@ -521,8 +524,10 @@ async def handle_request( """Entrypoint for `stream=False` calls.""" request_metadata = pickle.loads(pickled_request_metadata) with self._wrap_user_method_call(request_metadata, request_args): - return await self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs + return await asyncio.wrap_future( + self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs + ) ) async def handle_request_streaming( @@ -593,8 +598,10 @@ async def handle_request_with_rejection( ): yield result else: - yield await self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs + yield await asyncio.wrap_future( + self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs + ) ) async def handle_request_from_java( @@ -617,8 +624,10 @@ async def handle_request_from_java( route=proto.route, ) with self._wrap_user_method_call(request_metadata, request_args): - return await self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs + return await asyncio.wrap_future( + self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs + ) ) async def is_allocated(self) -> str: @@ -663,16 +672,18 @@ async def initialize_and_get_metadata( async with self._user_callable_initialized_lock: initialization_start_time = time.time() if not self._user_callable_initialized: - self._user_callable_asgi_app = ( - await self._user_callable_wrapper.initialize_callable() + self._user_callable_asgi_app = await asyncio.wrap_future( + self._user_callable_wrapper.initialize_callable() ) self._user_callable_initialized = True self._set_internal_replica_context( servable_object=self._user_callable_wrapper.user_callable ) if deployment_config: - await self._user_callable_wrapper.call_reconfigure( - deployment_config.user_config + await asyncio.wrap_future( + self._user_callable_wrapper.call_reconfigure( + deployment_config.user_config + ) ) # A new replica should not be considered healthy until it passes @@ -713,8 +724,10 @@ async def reconfigure( self._configure_logger_and_profilers(deployment_config.logging_config) if user_config_changed: - await self._user_callable_wrapper.call_reconfigure( - deployment_config.user_config + await asyncio.wrap_future( + self._user_callable_wrapper.call_reconfigure( + deployment_config.user_config + ) ) # We need to update internal replica context to reflect the new @@ -789,7 +802,7 @@ async def perform_graceful_shutdown(self): await self._drain_ongoing_requests() try: - await self._user_callable_wrapper.call_destructor() + await asyncio.wrap_future(self._user_callable_wrapper.call_destructor()) except: # noqa: E722 # We catch a blanket exception since the constructor may still be # running, so instance variables used by the destructor may not exist. @@ -804,20 +817,18 @@ async def perform_graceful_shutdown(self): await self._metrics_manager.shutdown() async def check_health(self): - await self._user_callable_wrapper.call_user_health_check() + # If there's no user-defined health check, nothing runs on the user code event + # loop and no future is returned. + f: Optional[ + concurrent.futures.Future + ] = self._user_callable_wrapper.call_user_health_check() + if f is not None: + await asyncio.wrap_future(f) class UserCallableWrapper: """Wraps a user-provided callable that is used to handle requests to a replica.""" - # All interactions with user code run on this loop to avoid blocking the replica's - # main event loop. - # NOTE(edoakes): this is a class variable rather than an instance variable to - # enable writing the `_run_on_user_code_event_loop` decorator method (the decorator - # doesn't have access to `self` at class definition time). - _user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - _user_code_event_loop_thread: Optional[threading.Thread] = None - def __init__( self, deployment_def: Callable, @@ -842,38 +853,37 @@ def __init__( # Will be populated in `initialize_callable`. self._callable = None - # Start the `_user_code_event_loop_thread` singleton if needed. - if self._user_code_event_loop_thread is None: + # All interactions with user code run on this loop to avoid blocking the + # replica's main event loop. + self._user_code_event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - def _run_user_code_event_loop(): - # Required so that calls to get the current running event loop work - # properly in user code. - asyncio.set_event_loop(self._user_code_event_loop) - self._user_code_event_loop.run_forever() + def _run_user_code_event_loop(): + # Required so that calls to get the current running event loop work + # properly in user code. + asyncio.set_event_loop(self._user_code_event_loop) + self._user_code_event_loop.run_forever() - self._user_code_event_loop_thread = threading.Thread( - daemon=True, - target=_run_user_code_event_loop, - ) - self._user_code_event_loop_thread.start() + self._user_code_event_loop_thread = threading.Thread( + daemon=True, + target=_run_user_code_event_loop, + ) + self._user_code_event_loop_thread.start() - def _run_on_user_code_event_loop(f: Callable): + def _run_on_user_code_event_loop(f: Callable) -> Callable: """Decorator to run a coroutine method on the user code event loop. - The method will be modified to be a sync function that returns an - `asyncio.Future`. + The method will be modified to be a sync function that returns a + `concurrent.futures.Future`. """ assert inspect.iscoroutinefunction( f ), "_run_on_user_code_event_loop can only be used on coroutine functions." @wraps(f) - def wrapper(*args, **kwargs) -> asyncio.Future: - return asyncio.wrap_future( - asyncio.run_coroutine_threadsafe( - f(*args, **kwargs), - UserCallableWrapper._user_code_event_loop, - ) + def wrapper(self, *args, **kwargs) -> concurrent.futures.Future: + return asyncio.run_coroutine_threadsafe( + f(self, *args, **kwargs), + self._user_code_event_loop, ) return wrapper @@ -985,17 +995,13 @@ async def initialize_callable(self) -> Optional[ASGIApp]: else None ) - @_run_on_user_code_event_loop - async def _call_user_health_check(self): - await self._call_func_or_gen(self._user_health_check) - def _raise_if_not_initialized(self, method_name: str): if self._callable is None: raise RuntimeError( "`initialize_callable` must be called before `{method_name}`." ) - async def call_user_health_check(self): + def call_user_health_check(self) -> Optional[concurrent.futures.Future]: self._raise_if_not_initialized("call_user_health_check") # If the user provided a health check, call it on the user code thread. If user @@ -1004,7 +1010,13 @@ async def call_user_health_check(self): # To avoid this issue for basic cases without a user-defined health check, skip # interacting with the user callable entirely. if self._user_health_check is not None: - return await self._call_user_health_check() + return self._call_user_health_check() + + return None + + @_run_on_user_code_event_loop + async def _call_user_health_check(self): + await self._call_func_or_gen(self._user_health_check) @_run_on_user_code_event_loop async def call_reconfigure(self, user_config: Any): diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index 53f91bda1ebb..d341c62db016 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures import pickle import sys import threading @@ -123,93 +124,102 @@ def _make_request_metadata( ) -@pytest.mark.asyncio -async def test_calling_initialize_twice(): +def test_calling_initialize_twice(): user_callable_wrapper = _make_user_callable_wrapper() - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() assert isinstance(user_callable_wrapper.user_callable, BasicClass) with pytest.raises(RuntimeError): - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() -@pytest.mark.asyncio -async def test_calling_methods_before_initialize(): +def test_calling_methods_before_initialize(): user_callable_wrapper = _make_user_callable_wrapper() with pytest.raises(RuntimeError): - await user_callable_wrapper.call_user_method(None, tuple(), dict()) + user_callable_wrapper.call_user_method(None, tuple(), dict()).result() with pytest.raises(RuntimeError): - await user_callable_wrapper.call_user_health_check() + user_callable_wrapper.call_user_health_check().result() with pytest.raises(RuntimeError): - await user_callable_wrapper.call_reconfigure(None) + user_callable_wrapper.call_reconfigure(None).result() -@pytest.mark.asyncio -async def test_basic_class_callable(): +def test_basic_class_callable(): user_callable_wrapper = _make_user_callable_wrapper() - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() # Call non-generator method with is_streaming. request_metadata = _make_request_metadata(is_streaming=True) with pytest.raises(RayTaskError, match="did not return a generator."): - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() # Test calling default sync `__call__` method. request_metadata = _make_request_metadata() assert ( - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() ) == "hi" assert ( - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, ("-arg",), dict() - ) - ) == "hi-arg" + ).result() + == "hi-arg" + ) assert ( - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, tuple(), {"suffix": "-kwarg"} - ) - ) == "hi-kwarg" + ).result() + == "hi-kwarg" + ) with pytest.raises(RayTaskError, match="uh-oh"): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, tuple(), {"raise_exception": True} - ) + ).result() # Call non-generator async method with is_streaming. request_metadata = _make_request_metadata( call_method="call_async", is_streaming=True ) with pytest.raises(RayTaskError, match="did not return a generator."): - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() # Test calling `call_async` method. request_metadata = _make_request_metadata(call_method="call_async") assert ( - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) - ) == "hi" + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() + == "hi" + ) assert ( - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, ("-arg",), dict() - ) - ) == "hi-arg" + ).result() + == "hi-arg" + ) assert ( - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, tuple(), {"suffix": "-kwarg"} - ) - ) == "hi-kwarg" + ).result() + == "hi-kwarg" + ) with pytest.raises(RayTaskError, match="uh-oh"): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, tuple(), {"raise_exception": True} - ) + ).result() -@pytest.mark.asyncio -async def test_basic_class_callable_generators(): +def test_basic_class_callable_generators(): user_callable_wrapper = _make_user_callable_wrapper() - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() result_list = [] @@ -223,28 +233,28 @@ async def append_to_list(item: Any): with pytest.raises( RayTaskError, match="Method 'call_generator' returned a generator." ): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ) + ).result() # Call sync generator. request_metadata = _make_request_metadata( call_method="call_generator", is_streaming=True ) - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ) + ).result() assert result_list == list(range(10)) result_list.clear() # Call sync generator raising exception. with pytest.raises(RayTaskError, match="uh-oh"): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), {"raise_exception": True}, generator_result_callback=append_to_list, - ) + ).result() assert result_list == [0] result_list.clear() @@ -255,67 +265,69 @@ async def append_to_list(item: Any): with pytest.raises( RayTaskError, match="Method 'call_async_generator' returned a generator." ): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ) + ).result() # Call async generator. request_metadata = _make_request_metadata( call_method="call_async_generator", is_streaming=True ) - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ) + ).result() assert result_list == list(range(10)) result_list.clear() # Call async generator raising exception. with pytest.raises(RayTaskError, match="uh-oh"): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), {"raise_exception": True}, generator_result_callback=append_to_list, - ) + ).result() assert result_list == [0] -@pytest.mark.asyncio @pytest.mark.parametrize("fn", [basic_sync_function, basic_async_function]) -async def test_basic_function_callable(fn: Callable): +def test_basic_function_callable(fn: Callable): user_callable_wrapper = _make_user_callable_wrapper(fn) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() # Call non-generator function with is_streaming. request_metadata = _make_request_metadata(is_streaming=True) with pytest.raises(RayTaskError, match="did not return a generator."): - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() request_metadata = _make_request_metadata() assert ( - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() ) == "hi" assert ( - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, ("-arg",), dict() - ) + ).result() ) == "hi-arg" assert ( - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, tuple(), {"suffix": "-kwarg"} - ) + ).result() ) == "hi-kwarg" with pytest.raises(RayTaskError, match="uh-oh"): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, tuple(), {"raise_exception": True} - ) + ).result() -@pytest.mark.asyncio @pytest.mark.parametrize("fn", [basic_sync_generator, basic_async_generator]) -async def test_basic_function_callable_generators(fn: Callable): +def test_basic_function_callable_generators(fn: Callable): user_callable_wrapper = _make_user_callable_wrapper(fn) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() result_list = [] @@ -327,28 +339,28 @@ async def append_to_list(item: Any): with pytest.raises( RayTaskError, match=f"Method '{fn.__name__}' returned a generator." ): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ) + ).result() # Call generator function. request_metadata = _make_request_metadata( call_method="call_generator", is_streaming=True ) - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), dict(), generator_result_callback=append_to_list - ) + ).result() assert result_list == list(range(10)) result_list.clear() # Call generator function raising exception. with pytest.raises(RayTaskError, match="uh-oh"): - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (10,), {"raise_exception": True}, generator_result_callback=append_to_list, - ) + ).result() assert result_list == [0] @@ -375,20 +387,19 @@ def __call__(self) -> asyncio.AbstractEventLoop: return user_method_loop user_callable_wrapper = _make_user_callable_wrapper(GetLoop) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() request_metadata = _make_request_metadata() - user_code_loop = await user_callable_wrapper.call_user_method( + user_code_loop = user_callable_wrapper.call_user_method( request_metadata, tuple(), dict() - ) + ).result() assert isinstance(user_code_loop, asyncio.AbstractEventLoop) assert user_code_loop != main_loop # `check_health` method asserts that it runs on the correct loop. - await user_callable_wrapper.call_user_health_check() + user_callable_wrapper.call_user_health_check().result() -@pytest.mark.asyncio -async def test_callable_with_async_init(): +def test_callable_with_async_init(): class AsyncInitializer: async def __init__(self, msg: str): await asyncio.sleep(0.001) @@ -402,16 +413,17 @@ def __call__(self) -> str: AsyncInitializer, msg, ) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() request_metadata = _make_request_metadata() assert ( - await user_callable_wrapper.call_user_method(request_metadata, tuple(), dict()) + user_callable_wrapper.call_user_method( + request_metadata, tuple(), dict() + ).result() ) == msg -@pytest.mark.asyncio @pytest.mark.parametrize("async_del", [False, True]) -async def test_destructor_only_called_once(async_del: bool): +def test_destructor_only_called_once(async_del: bool): num_destructor_calls = 0 if async_del: @@ -431,15 +443,12 @@ def __del__(self) -> str: user_callable_wrapper = _make_user_callable_wrapper( DestroyerOfNothing, ) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() # Call `call_destructor` many times in parallel; only the first one should actually # run the `__del__` method. - await asyncio.gather( - *[ - asyncio.ensure_future(user_callable_wrapper.call_destructor()) - for _ in range(100) - ] + concurrent.futures.wait( + [user_callable_wrapper.call_destructor() for _ in range(100)] ) assert num_destructor_calls == 1 @@ -461,20 +470,22 @@ async def __call__(self) -> str: user_callable_wrapper = _make_user_callable_wrapper( LoopBlocker, ) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() request_metadata = _make_request_metadata() blocked_future = user_callable_wrapper.call_user_method( request_metadata, tuple(), dict() ) - _, pending = await asyncio.wait([blocked_future], timeout=0.01) + _, pending = concurrent.futures.wait([blocked_future], timeout=0.01) assert len(pending) == 1 for _ in range(100): # If this called something on the event loop, it'd be blocked. - await user_callable_wrapper.call_user_health_check() + # Instead, `user_callable_wrapper.call_user_health_check` returns None + # when there's no user health check configured. + assert user_callable_wrapper.call_user_health_check() is None sync_event.set() - assert await blocked_future == "Sorry I got stuck!" + assert blocked_future.result() == "Sorry I got stuck!" class gRPCClass: @@ -486,19 +497,18 @@ def stream(self, msg: serve_pb2.UserDefinedMessage): yield serve_pb2.UserDefinedResponse(greeting=f"Hello {msg.greeting} {i}!") -@pytest.mark.asyncio -async def test_grpc_unary_request(): +def test_grpc_unary_request(): user_callable_wrapper = _make_user_callable_wrapper(gRPCClass) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() grpc_request = gRPCRequest( pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) ) request_metadata = _make_request_metadata(call_method="greet", is_grpc_request=True) - _, result_bytes = await user_callable_wrapper.call_user_method( + _, result_bytes = user_callable_wrapper.call_user_method( request_metadata, (grpc_request,), dict() - ) + ).result() assert isinstance(result_bytes, bytes) result = serve_pb2.UserDefinedResponse() @@ -507,9 +517,9 @@ async def test_grpc_unary_request(): @pytest.mark.asyncio -async def test_grpc_streaming_request(): +def test_grpc_streaming_request(): user_callable_wrapper = _make_user_callable_wrapper(gRPCClass) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable() grpc_request = gRPCRequest( pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) @@ -523,12 +533,12 @@ async def append_to_list(item: Any): request_metadata = _make_request_metadata( call_method="stream", is_grpc_request=True, is_streaming=True ) - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (grpc_request,), dict(), generator_result_callback=append_to_list, - ) + ).result() assert len(result_list) == 10 for i, (_, result_bytes) in enumerate(result_list): @@ -556,11 +566,10 @@ async def handle_root(self, request: Request) -> str: return PlainTextResponse(f"Hello {msg}!") -@pytest.mark.asyncio @pytest.mark.parametrize("callable", [RawRequestHandler, FastAPIRequestHandler]) -async def test_http_handler(callable: Callable, monkeypatch): +def test_http_handler(callable: Callable, monkeypatch): user_callable_wrapper = _make_user_callable_wrapper(callable) - await user_callable_wrapper.initialize_callable() + user_callable_wrapper.initialize_callable().result() @dataclass class MockReplicaContext: @@ -611,12 +620,12 @@ async def append_to_list(item: Any): result_list.append(item) request_metadata = _make_request_metadata(is_http_request=True, is_streaming=True) - await user_callable_wrapper.call_user_method( + user_callable_wrapper.call_user_method( request_metadata, (http_request,), dict(), generator_result_callback=append_to_list, - ) + ).result() assert result_list[0]["type"] == "http.response.start" assert result_list[0]["status"] == 200