From 267b14ecbf6f550c1bc54e7fc8d35e579177236d Mon Sep 17 00:00:00 2001 From: Gene Der Su Date: Fri, 7 Jul 2023 11:50:13 -0700 Subject: [PATCH] [Serve] Fix serve non atomic shutdown (#36927) Currently we are relying on the client to wait for all the resources before shutting off the controller. This caused the issue for when they interrupt the process and can cause incomplete shutdown. In this PR we moved the shutdown logic into the event loop which would be triggered by a `_shutting_down` flag on the controller. So even if the client interrupted the process, the controller will continue to shutdown all the resources and then kill itself. --- python/ray/serve/BUILD | 8 ++ .../ray/serve/_private/application_state.py | 25 ++++- python/ray/serve/_private/client.py | 59 ++--------- python/ray/serve/_private/deployment_state.py | 15 ++- python/ray/serve/_private/endpoint_state.py | 8 ++ python/ray/serve/_private/http_state.py | 32 ++++++ python/ray/serve/controller.py | 98 ++++++++++++++++- .../ray/serve/tests/test_application_state.py | 46 ++++++++ python/ray/serve/tests/test_cli.py | 6 +- .../ray/serve/tests/test_deployment_state.py | 6 ++ python/ray/serve/tests/test_endpoint_state.py | 65 ++++++++++++ python/ray/serve/tests/test_http_state.py | 63 +++++++---- python/ray/serve/tests/test_standalone3.py | 100 ++++++++++++++++++ 13 files changed, 457 insertions(+), 74 deletions(-) create mode 100644 python/ray/serve/tests/test_endpoint_state.py diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 4b250657a556..45f2cfdc75b8 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -681,3 +681,11 @@ py_test( tags = ["exclusive", "team:serve"], deps = [":serve_lib"], ) + +py_test( + name = "test_endpoint_state", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) diff --git a/python/ray/serve/_private/application_state.py b/python/ray/serve/_private/application_state.py index 89f3d5ef7f65..815cda6631fd 100644 --- a/python/ray/serve/_private/application_state.py +++ b/python/ray/serve/_private/application_state.py @@ -177,9 +177,20 @@ def _delete_deployment(self, name): def delete(self): """Delete the application""" - logger.info(f"Deleting application '{self._name}'") + logger.info( + f"Deleting application '{self._name}'", + extra={"log_to_stderr": False}, + ) self._set_target_state(deleting=True) + def is_deleted(self) -> bool: + """Check whether the application is already deleted. + + For an application to be considered deleted, the target state has to be set to + deleting and all deployments have to be deleted. + """ + return self._target_state.deleting and len(self._get_live_deployments()) == 0 + def apply_deployment_info( self, deployment_name: str, deployment_info: DeploymentInfo ) -> None: @@ -412,7 +423,7 @@ def update(self) -> bool: # Check if app is ready to be deleted if self._target_state.deleting: - return len(self._get_live_deployments()) == 0 + return self.is_deleted() return False def get_checkpoint_data(self) -> ApplicationTargetState: @@ -639,6 +650,16 @@ def shutdown(self) -> None: for app_state in self._application_states.values(): app_state.delete() + def is_ready_for_shutdown(self) -> bool: + """Return whether all applications have shut down. + + Iterate through all application states and check if all their applications + are deleted. + """ + return all( + app_state.is_deleted() for app_state in self._application_states.values() + ) + def _save_checkpoint_func( self, *, writeahead_checkpoints: Optional[Dict[str, ApplicationTargetState]] ) -> None: diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py index 7c9a61358f27..31acdfcb4cb9 100644 --- a/python/ray/serve/_private/client.py +++ b/python/ray/serve/_private/client.py @@ -20,7 +20,6 @@ CLIENT_POLLING_INTERVAL_S, CLIENT_CHECK_CREATION_POLLING_INTERVAL_S, MAX_CACHED_HANDLES, - SERVE_NAMESPACE, SERVE_DEFAULT_APP_NAME, ) from ray.serve._private.deploy_utils import get_deploy_args @@ -95,7 +94,7 @@ def __del__(self): def __reduce__(self): raise RayServeException(("Ray Serve client cannot be serialized.")) - def shutdown(self) -> None: + def shutdown(self, timeout_s: float = 30.0) -> None: """Completely shut down the connected Serve instance. Shuts down all processes and deletes all state associated with the @@ -107,53 +106,17 @@ def shutdown(self) -> None: del self.handle_cache[k] if ray.is_initialized() and not self._shutdown: - ray.get(self._controller.shutdown.remote()) - self._wait_for_deployments_shutdown() - - ray.kill(self._controller, no_restart=True) - - # Wait for the named actor entry gets removed as well. - started = time.time() - while True: - try: - ray.get_actor(self._controller_name, namespace=SERVE_NAMESPACE) - if time.time() - started > 5: - logger.warning( - "Waited 5s for Serve to shutdown gracefully but " - "the controller is still not cleaned up. " - "You can ignore this warning if you are shutting " - "down the Ray cluster." - ) - break - except ValueError: # actor name is removed - break - - self._shutdown = True - - def _wait_for_deployments_shutdown(self, timeout_s: int = 60): - """Waits for all deployments to be shut down and deleted. - - Raises TimeoutError if this doesn't happen before timeout_s. - """ - start = time.time() - while time.time() - start < timeout_s: - deployment_statuses = self.get_all_deployment_statuses() - if len(deployment_statuses) == 0: - break - else: - logger.debug( - f"Waiting for shutdown, {len(deployment_statuses)} " - "deployments still alive." + try: + ray.get(self._controller.graceful_shutdown.remote(), timeout=timeout_s) + except ray.exceptions.RayActorError: + # Controller has been shut down. + pass + except TimeoutError: + logger.warning( + f"Controller failed to shut down within {timeout_s}s. " + "Check controller logs for more details." ) - time.sleep(CLIENT_POLLING_INTERVAL_S) - else: - live_names = [ - deployment_status.name for deployment_status in deployment_statuses - ] - raise TimeoutError( - f"Shutdown didn't complete after {timeout_s}s. " - f"Deployments still alive: {live_names}." - ) + self._shutdown = True def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1): """Waits for the named deployment to enter "HEALTHY" status. diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 603c30957f61..cc030a52df69 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1279,7 +1279,10 @@ def _set_target_state_deleting(self) -> None: self._curr_status_info = DeploymentStatusInfo( self._name, DeploymentStatus.UPDATING ) - logger.info(f"Deleting deployment {self._name}.") + logger.info( + f"Deleting deployment {self._name}.", + extra={"log_to_stderr": False}, + ) def _set_target_state(self, target_info: DeploymentInfo) -> None: """Set the target state for the deployment to the provided info.""" @@ -2284,6 +2287,16 @@ def shutdown(self): # TODO(jiaodong): Need to add some logic to prevent new replicas # from being created once shutdown signal is sent. + def is_ready_for_shutdown(self) -> bool: + """Return whether all deployments are shutdown. + + Check there are no deployment states and no checkpoints. + """ + return ( + len(self._deployment_states) == 0 + and self._kv_store.get(CHECKPOINT_KEY) is None + ) + def _save_checkpoint_func( self, *, writeahead_checkpoints: Optional[Dict[str, Tuple]] ) -> None: diff --git a/python/ray/serve/_private/endpoint_state.py b/python/ray/serve/_private/endpoint_state.py index e6662892173f..949e2ad08537 100644 --- a/python/ray/serve/_private/endpoint_state.py +++ b/python/ray/serve/_private/endpoint_state.py @@ -34,6 +34,14 @@ def __init__(self, kv_store: KVStoreBase, long_poll_host: LongPollHost): def shutdown(self): self._kv_store.delete(CHECKPOINT_KEY) + def is_ready_for_shutdown(self) -> bool: + """Returns whether the endpoint checkpoint has been deleted. + + Get the endpoint checkpoint from the kv store. If it is None, then it has been + deleted. + """ + return self._kv_store.get(CHECKPOINT_KEY) is None + def _checkpoint(self): self._kv_store.put(CHECKPOINT_KEY, cloudpickle.dumps(self._endpoints)) diff --git a/python/ray/serve/_private/http_state.py b/python/ray/serve/_private/http_state.py index 45cd98003e14..71a4ea97e8ae 100644 --- a/python/ray/serve/_private/http_state.py +++ b/python/ray/serve/_private/http_state.py @@ -230,6 +230,27 @@ def shutdown(self): self._shutting_down = True ray.kill(self.actor_handle, no_restart=True) + def is_ready_for_shutdown(self) -> bool: + """Return whether the HTTP proxy actor is shutdown. + + For an HTTP proxy actor to be considered shutdown, it must be marked as + _shutting_down and the actor must be dead. If the actor is dead, the health + check will return RayActorError. + """ + if not self._shutting_down: + return False + + try: + ray.get(self._actor_handle.check_health.remote(), timeout=0.001) + except ray.exceptions.RayActorError: + # The actor is dead, so it's ready for shutdown. + return True + except ray.exceptions.GetTimeoutError: + # The actor is still alive, so it's not ready for shutdown. + return False + + return False + class HTTPState: """Manages all state for HTTP proxies in the system. @@ -269,6 +290,17 @@ def shutdown(self) -> None: for proxy_state in self._proxy_states.values(): proxy_state.shutdown() + def is_ready_for_shutdown(self) -> bool: + """Return whether all proxies are shutdown. + + Iterate through all proxy states and check if all their proxy actors + are shutdown. + """ + return all( + proxy_state.is_ready_for_shutdown() + for proxy_state in self._proxy_states.values() + ) + def get_config(self): return self._config diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 1d1b224470c8..7b718d080211 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -187,6 +187,9 @@ async def __init__( worker_id=ray.get_runtime_context().get_worker_id(), log_file_path=get_component_logger_file_path(), ) + self._shutting_down = False + self._shutdown = asyncio.Event() + self._shutdown_start_time = None run_background_task(self.run_control_loop()) @@ -307,6 +310,12 @@ async def run_control_loop(self) -> None: recovering_timeout = RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S start_time = time.time() while True: + if self._shutting_down: + try: + self.shutdown() + except Exception: + logger.exception("Exception during shutdown.") + if ( not self.done_recovering_event.is_set() and time.time() - start_time > recovering_timeout @@ -439,8 +448,29 @@ def get_root_url(self): ) return http_config.root_url + def config_checkpoint_deleted(self) -> bool: + """Returns whether the config checkpoint has been deleted. + + Get the config checkpoint from the kv store. If it is None, then it has been + deleted. + """ + return self.kv_store.get(CONFIG_CHECKPOINT_KEY) is None + def shutdown(self): - """Shuts down the serve instance completely.""" + """Shuts down the serve instance completely. + + This method will only be triggered when `self._shutting_down` is true. It + deletes the kv store for config checkpoints, sets application state to deleting, + delete all deployments, and shuts down all HTTP proxies. Once all these + resources are released, it then kills the controller actor. + """ + if not self._shutting_down: + return + + if self._shutdown_start_time is None: + self._shutdown_start_time = time.time() + + logger.info("Controller shutdown started!", extra={"log_to_stderr": False}) self.kv_store.delete(CONFIG_CHECKPOINT_KEY) self.application_state_manager.shutdown() self.deployment_state_manager.shutdown() @@ -448,6 +478,54 @@ def shutdown(self): if self.http_state: self.http_state.shutdown() + config_checkpoint_deleted = self.config_checkpoint_deleted() + application_is_shutdown = self.application_state_manager.is_ready_for_shutdown() + deployment_is_shutdown = self.deployment_state_manager.is_ready_for_shutdown() + endpoint_is_shutdown = self.endpoint_state.is_ready_for_shutdown() + http_state_is_shutdown = ( + self.http_state is None or self.http_state.is_ready_for_shutdown() + ) + if ( + config_checkpoint_deleted + and application_is_shutdown + and deployment_is_shutdown + and endpoint_is_shutdown + and http_state_is_shutdown + ): + logger.warning( + "All resources have shut down, shutting down controller!", + extra={"log_to_stderr": False}, + ) + _controller_actor = ray.get_runtime_context().current_actor + self._shutdown.set() + ray.kill(_controller_actor, no_restart=True) + elif time.time() - self._shutdown_start_time > 10: + if not config_checkpoint_deleted: + logger.warning( + f"{CONFIG_CHECKPOINT_KEY} not yet deleted", + extra={"log_to_stderr": False}, + ) + if not application_is_shutdown: + logger.warning( + "application not yet shutdown", + extra={"log_to_stderr": False}, + ) + if not deployment_is_shutdown: + logger.warning( + "deployment not yet shutdown", + extra={"log_to_stderr": False}, + ) + if not endpoint_is_shutdown: + logger.warning( + "endpoint not yet shutdown", + extra={"log_to_stderr": False}, + ) + if not http_state_is_shutdown: + logger.warning( + "http_state not yet shutdown", + extra={"log_to_stderr": False}, + ) + def deploy( self, name: str, @@ -835,6 +913,20 @@ def record_multiplexed_replica_info(self, info: MultiplexedReplicaInfo): """ self.deployment_state_manager.record_multiplexed_replica_info(info) + async def graceful_shutdown(self, wait: bool = True): + """Set the shutting down flag on controller to signal shutdown in + run_control_loop(). + + This is used to signal to the controller that it should proceed with shutdown + process, so it can shut down gracefully. It also waits until the shutdown + event is triggered if wait is true. + """ + self._shutting_down = True + if not wait: + return + + await self._shutdown.wait() + @ray.remote(num_cpus=0, max_calls=1) def deploy_serve_application( @@ -943,7 +1035,7 @@ def __init__( http_proxy_port: int = 8000, ): try: - self._controller = ray.get_actor(controller_name, namespace="serve") + self._controller = ray.get_actor(controller_name, namespace=SERVE_NAMESPACE) except ValueError: self._controller = None if self._controller is None: @@ -956,7 +1048,7 @@ def __init__( max_restarts=-1, max_task_retries=-1, resources={HEAD_NODE_RESOURCE_NAME: 0.001}, - namespace="serve", + namespace=SERVE_NAMESPACE, max_concurrency=CONTROLLER_MAX_CONCURRENCY, ).remote( controller_name, diff --git a/python/ray/serve/tests/test_application_state.py b/python/ray/serve/tests/test_application_state.py index 0f7afc16098c..8bdb18e41461 100644 --- a/python/ray/serve/tests/test_application_state.py +++ b/python/ray/serve/tests/test_application_state.py @@ -658,5 +658,51 @@ def test_recover_during_update(mocked_application_state_manager): assert app_state.status == ApplicationStatus.RUNNING +def test_is_ready_for_shutdown(mocked_application_state_manager): + """Test `is_ready_for_shutdown()` returns the correct state. + + When shutting down applications before deployments are deleted, application state + `is_deleted()` should return False and `is_ready_for_shutdown()` should return + False. When shutting down applications after deployments are deleted, application + state `is_deleted()` should return True and `is_ready_for_shutdown()` should return + True. + """ + ( + app_state_manager, + deployment_state_manager, + kv_store, + ) = mocked_application_state_manager + app_name = "test_app" + deployment_name = "d1" + + # DEPLOY application with deployment "d1" + params = deployment_params(deployment_name) + app_state_manager.apply_deployment_args(app_name, [params]) + app_state = app_state_manager._application_states[app_name] + assert app_state.status == ApplicationStatus.DEPLOYING + + # Once deployment is healthy, app should be running + app_state_manager.update() + assert deployment_state_manager.get_deployment(deployment_name) + deployment_state_manager.set_deployment_healthy(deployment_name) + app_state_manager.update() + assert app_state.status == ApplicationStatus.RUNNING + + # When shutting down applications before deployments are deleted, application state + # `is_deleted()` should return False and `is_ready_for_shutdown()` should return + # False + app_state_manager.shutdown() + assert not app_state.is_deleted() + assert not app_state_manager.is_ready_for_shutdown() + + # When shutting down applications after deployments are deleted, application state + # `is_deleted()` should return True and `is_ready_for_shutdown()` should return True + deployment_state_manager.delete_deployment(deployment_name) + deployment_state_manager.set_deployment_deleted(deployment_name) + app_state_manager.update() + assert app_state.is_deleted() + assert app_state_manager.is_ready_for_shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_cli.py b/python/ray/serve/tests/test_cli.py index 21a3a633592e..1526588d367f 100644 --- a/python/ray/serve/tests/test_cli.py +++ b/python/ray/serve/tests/test_cli.py @@ -828,8 +828,9 @@ def parrot(request): parrot_node = parrot.bind() +@pytest.mark.parametrize("number_of_kill_signals", (1, 2)) @pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.") -def test_run_application(ray_start_stop): +def test_run_application(ray_start_stop, number_of_kill_signals): """Deploys valid config file and import path via `serve run`.""" # Deploy via config file @@ -849,7 +850,8 @@ def test_run_application(ray_start_stop): ) print("Run successful! Deployments are live and reachable over HTTP. Killing run.") - p.send_signal(signal.SIGINT) # Equivalent to ctrl-C + for _ in range(number_of_kill_signals): + p.send_signal(signal.SIGINT) # Equivalent to ctrl-C p.wait() with pytest.raises(requests.exceptions.ConnectionError): requests.post("http://localhost:8000/", json=["ADD", 0]).json() diff --git a/python/ray/serve/tests/test_deployment_state.py b/python/ray/serve/tests/test_deployment_state.py index 97bddee4577c..ef3dd6ebdb23 100644 --- a/python/ray/serve/tests/test_deployment_state.py +++ b/python/ray/serve/tests/test_deployment_state.py @@ -2525,6 +2525,9 @@ def test_shutdown(mock_deployment_state_manager, is_driver_deployment): # Test shutdown flow assert not deployment_state._replicas.get()[0]._actor.stopped + # Before shutdown, `is_ready_for_shutdown()` should return False + assert not deployment_state_manager.is_ready_for_shutdown() + deployment_state_manager.shutdown() timer.advance(grace_period_s + 0.1) @@ -2541,6 +2544,9 @@ def test_shutdown(mock_deployment_state_manager, is_driver_deployment): check_counts(deployment_state, total=0) assert len(deployment_state_manager.get_deployment_statuses()) == 0 + # After all deployments shutdown, `is_ready_for_shutdown()` should return True + assert deployment_state_manager.is_ready_for_shutdown() + def test_resource_requirements_none(): """Ensure resource_requirements doesn't break if a requirement is None""" diff --git a/python/ray/serve/tests/test_endpoint_state.py b/python/ray/serve/tests/test_endpoint_state.py new file mode 100644 index 000000000000..23c68011a3f1 --- /dev/null +++ b/python/ray/serve/tests/test_endpoint_state.py @@ -0,0 +1,65 @@ +import sys +from typing import Any, Tuple +from unittest.mock import patch, Mock + +import pytest +from ray.serve._private.endpoint_state import EndpointState + + +class MockKVStore: + def __init__(self): + self.store = dict() + + def put(self, key: str, val: Any) -> bool: + if not isinstance(key, str): + raise TypeError("key must be a string, got: {}.".format(type(key))) + self.store[key] = val + return True + + def get(self, key: str) -> Any: + if not isinstance(key, str): + raise TypeError("key must be a string, got: {}.".format(type(key))) + return self.store.get(key, None) + + def delete(self, key: str) -> bool: + if not isinstance(key, str): + raise TypeError("key must be a string, got: {}.".format(type(key))) + + if key in self.store: + del self.store[key] + return True + + return False + + +@pytest.fixture +def mock_endpoint_state() -> Tuple[EndpointState, Mock]: + with patch("ray.serve._private.long_poll.LongPollHost") as mock_long_poll: + endpoint_state = EndpointState( + kv_store=MockKVStore(), + long_poll_host=mock_long_poll, + ) + yield endpoint_state + + +def test_is_ready_for_shutdown(mock_endpoint_state): + """Test `is_ready_for_shutdown()` returns the correct state. + + Before shutting down endpoint `is_ready_for_shutdown()` should return False. + After shutting down endpoint `is_ready_for_shutdown()` should return True. + """ + # Setup endpoint state with checkpoint + endpoint_state = mock_endpoint_state + endpoint_state._checkpoint() + + # Before shutdown is called, `is_ready_for_shutdown()` should return False + assert not endpoint_state.is_ready_for_shutdown() + + endpoint_state.shutdown() + + # After shutdown is called, `is_ready_for_shutdown()` should return True + assert endpoint_state.is_ready_for_shutdown() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_http_state.py b/python/ray/serve/tests/test_http_state.py index 3a07af1c9753..dbe56ec263c4 100644 --- a/python/ray/serve/tests/test_http_state.py +++ b/python/ray/serve/tests/test_http_state.py @@ -157,9 +157,7 @@ def test_node_selection(all_nodes, mock_get_all_node_ids): assert set(another_seed) != set(selected_nodes) -def test_http_state_update_restarts_unhealthy_proxies( - mock_get_all_node_ids, setup_controller -): +def test_http_state_update_restarts_unhealthy_proxies(mock_get_all_node_ids): """Test the update method in HTTPState would kill and restart unhealthy proxies. Set up a HTTPProxyState with UNHEALTHY status. Calls the update method on the @@ -199,7 +197,7 @@ def _update_state_and_check_proxy_status( assert new_proxy != old_proxy -def test_http_proxy_state_update_shutting_down(setup_controller): +def test_http_proxy_state_update_shutting_down(): """Test calling update method on HTTPProxyState when the proxy state is shutting down. @@ -218,7 +216,7 @@ def test_http_proxy_state_update_shutting_down(setup_controller): assert previous_status == current_status -def test_http_proxy_state_update_starting_ready_succeed(setup_controller): +def test_http_proxy_state_update_starting_ready_succeed(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call succeeded. @@ -238,7 +236,7 @@ def test_http_proxy_state_update_starting_ready_succeed(setup_controller): ) -def test_http_proxy_state_update_starting_ready_failed_once(setup_controller): +def test_http_proxy_state_update_starting_ready_failed_once(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call failed once and succeeded for the following call. @@ -279,7 +277,7 @@ async def check_health(self): ) -def test_http_proxy_state_update_starting_ready_always_fails(setup_controller): +def test_http_proxy_state_update_starting_ready_always_fails(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call is always failing. @@ -313,7 +311,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_READY_CHECK_TIMEOUT_S", 1) -def test_http_proxy_state_update_starting_ready_always_timeout(setup_controller): +def test_http_proxy_state_update_starting_ready_always_timeout(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call always timed out. @@ -344,7 +342,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_PERIOD_S", 0.1) -def test_http_proxy_state_update_healthy_check_health_succeed(setup_controller): +def test_http_proxy_state_update_healthy_check_health_succeed(): """Test calling update method on HTTPProxyState when the proxy state is HEALTHY and when the check_health call succeeded @@ -375,7 +373,7 @@ def test_http_proxy_state_update_healthy_check_health_succeed(setup_controller): @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_PERIOD_S", 0.1) -def test_http_proxy_state_update_healthy_check_health_failed_once(setup_controller): +def test_http_proxy_state_update_healthy_check_health_failed_once(): """Test calling update method on HTTPProxyState when the proxy state is HEALTHY and when the check_health call failed once and succeeded for the following call. @@ -426,7 +424,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_PERIOD_S", 0.1) -def test_http_proxy_state_update_healthy_check_health_always_fails(setup_controller): +def test_http_proxy_state_update_healthy_check_health_always_fails(): """Test calling update method on HTTPProxyState when the proxy state is HEALTHY and when the check_health call is always failing. @@ -471,9 +469,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_TIMEOUT_S", 0.1) @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_PERIOD_S", 0.1) -def test_http_proxy_state_check_health_always_timeout_timeout_eq_period( - setup_controller, -): +def test_http_proxy_state_check_health_always_timeout_timeout_eq_period(): """Test calling update method on HTTPProxyState when the proxy state is HEALTHY and when the ready call always timed out and health check timeout and period equals. @@ -518,9 +514,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_TIMEOUT_S", 1) @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_PERIOD_S", 0.1) -def test_http_proxy_state_check_health_always_timeout_timeout_greater_than_period( - setup_controller, -): +def test_http_proxy_state_check_health_always_timeout_timeout_greater_than_period(): """Test calling update method on HTTPProxyState when the proxy state is HEALTHY and when the ready call always timed out and health check timeout greater than period. @@ -564,7 +558,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_HEALTH_CHECK_PERIOD_S", 0.1) -def test_http_proxy_state_update_unhealthy_check_health_succeed(setup_controller): +def test_http_proxy_state_update_unhealthy_check_health_succeed(): """Test calling update method on HTTPProxyState when the proxy state is UNHEALTHY and when the check_health call succeeded. @@ -682,6 +676,39 @@ def test_update_draining(mock_get_all_node_ids, setup_controller, all_nodes): ) +def test_is_ready_for_shutdown(mock_get_all_node_ids, all_nodes): + """Test `is_ready_for_shutdown()` returns True the correct state. + + Before `shutdown()` is called, `is_ready_for_shutdown()` should return false. After + `shutdown()` is called and all proxy actor are killed, `is_ready_for_shutdown()` + should return true. + """ + state = _make_http_state(HTTPOptions(location=DeploymentMode.EveryNode)) + + for node_id, node_ip_address in all_nodes: + state._proxy_states[node_id] = _create_http_proxy_state( + proxy_actor_class=HTTPProxyActor, + status=HTTPProxyStatus.HEALTHY, + node_id=node_id, + host="localhost", + port=8000, + root_path="/", + controller_name=SERVE_CONTROLLER_NAME, + node_ip_address=node_ip_address, + ) + + # Ensure before shutdown, state is not shutdown + assert not state.is_ready_for_shutdown() + + state.shutdown() + + # Ensure after shutdown, state is shutdown and all proxy states are shutdown + def check_is_ready_for_shutdown(): + return state.is_ready_for_shutdown() + + wait_for_condition(check_is_ready_for_shutdown) + + if __name__ == "__main__": import sys diff --git a/python/ray/serve/tests/test_standalone3.py b/python/ray/serve/tests/test_standalone3.py index 0fbfdf457ea4..8ee0e5c7e851 100644 --- a/python/ray/serve/tests/test_standalone3.py +++ b/python/ray/serve/tests/test_standalone3.py @@ -506,5 +506,105 @@ def _check(): serve.shutdown() +@pytest.mark.parametrize("wait_for_controller_shutdown", (True, False)) +def test_controller_shutdown_gracefully( + shutdown_ray, call_ray_stop_only, wait_for_controller_shutdown # noqa: F811 +): + """Test controller shutdown gracefully when calling `graceful_shutdown()`. + + Called `graceful_shutdown()` on the controller, so it will start shutdown and + eventually all actors will be in DEAD state. Test both cases whether to wait for + the controller shutdown or not should both resolve graceful shutdown. + """ + # Setup a cluster with 2 nodes + cluster = Cluster() + cluster.add_node() + cluster.add_node() + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + # Deploy 2 replicas + @serve.deployment(num_replicas=2) + class HelloModel: + def __call__(self): + return "hello" + + model = HelloModel.bind() + serve.run(target=model) + + # Ensure total actors of 2 proxies, 1 controller, and 2 replicas + wait_for_condition(lambda: len(ray._private.state.actors()) == 5) + assert len(ray.nodes()) == 2 + + # Call `graceful_shutdown()` on the controller, so it will start shutdown. + client = get_global_client() + if wait_for_controller_shutdown: + # Waiting for controller shutdown will throw RayActorError when the controller + # killed itself. + with pytest.raises(ray.exceptions.RayActorError): + ray.get(client._controller.graceful_shutdown.remote(True)) + else: + ray.get(client._controller.graceful_shutdown.remote(False)) + + # Ensure the all resources are shutdown. + wait_for_condition( + lambda: all( + [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + ) + ) + + # Clean up serve. + serve.shutdown() + + +def test_client_shutdown_gracefully_when_timeout( + shutdown_ray, call_ray_stop_only, caplog # noqa: F811 +): + """Test client shutdown gracefully when timeout. + + When the controller is taking longer than the timeout to shutdown, the client will + log timeout message and exit the process. The controller will continue to shutdown + everything gracefully. + """ + # Setup a cluster with 2 nodes + cluster = Cluster() + cluster.add_node() + cluster.add_node() + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + # Deploy 2 replicas + @serve.deployment(num_replicas=2) + class HelloModel: + def __call__(self): + return "hello" + + model = HelloModel.bind() + serve.run(target=model) + + # Ensure total actors of 2 proxies, 1 controller, and 2 replicas + wait_for_condition(lambda: len(ray._private.state.actors()) == 5) + assert len(ray.nodes()) == 2 + + # Ensure client times out if the controller does not shutdown within timeout. + timeout_s = 0.0 + client = get_global_client() + client.shutdown(timeout_s=timeout_s) + assert ( + f"Controller failed to shut down within {timeout_s}s. " + f"Check controller logs for more details." in caplog.text + ) + + # Ensure the all resources are shutdown gracefully. + wait_for_condition( + lambda: all( + [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + ), + ) + + # Clean up serve. + serve.shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__]))