diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD
index 4b250657a5562..45f2cfdc75b8b 100644
--- a/python/ray/serve/BUILD
+++ b/python/ray/serve/BUILD
@@ -681,3 +681,11 @@ py_test(
     tags = ["exclusive", "team:serve"],
     deps = [":serve_lib"],
 )
+
+py_test(
+    name = "test_endpoint_state",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive", "team:serve"],
+    deps = [":serve_lib"],
+)
diff --git a/python/ray/serve/_private/application_state.py b/python/ray/serve/_private/application_state.py
index 89f3d5ef7f65e..815cda6631fd2 100644
--- a/python/ray/serve/_private/application_state.py
+++ b/python/ray/serve/_private/application_state.py
@@ -177,9 +177,20 @@ def _delete_deployment(self, name):
 
     def delete(self):
         """Delete the application"""
-        logger.info(f"Deleting application '{self._name}'")
+        logger.info(
+            f"Deleting application '{self._name}'",
+            extra={"log_to_stderr": False},
+        )
         self._set_target_state(deleting=True)
 
+    def is_deleted(self) -> bool:
+        """Check whether the application is already deleted.
+
+        For an application to be considered deleted, the target state has to be
+        set to deleting and all deployments have to be deleted.
+        """
+        return self._target_state.deleting and len(self._get_live_deployments()) == 0
+
     def apply_deployment_info(
         self, deployment_name: str, deployment_info: DeploymentInfo
     ) -> None:
@@ -412,7 +423,7 @@ def update(self) -> bool:
 
         # Check if app is ready to be deleted
         if self._target_state.deleting:
-            return len(self._get_live_deployments()) == 0
+            return self.is_deleted()
         return False
 
     def get_checkpoint_data(self) -> ApplicationTargetState:
@@ -639,6 +650,16 @@ def shutdown(self) -> None:
         for app_state in self._application_states.values():
             app_state.delete()
 
+    def is_ready_for_shutdown(self) -> bool:
+        """Return whether all applications have shut down.
+
+        Iterates through all application states and checks that every
+        application is deleted.
+        """
+        return all(
+            app_state.is_deleted() for app_state in self._application_states.values()
+        )
+
     def _save_checkpoint_func(
         self, *, writeahead_checkpoints: Optional[Dict[str, ApplicationTargetState]]
     ) -> None:
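The two new checks above compose: an application counts as deleted once its target state is `deleting` and no live deployments remain, and the manager is ready for shutdown once every application passes that check (vacuously true when no applications remain). A minimal sketch of the same aggregation pattern, using illustrative stand-in classes rather than the real Serve types:

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class AppState:
        deleting: bool = False
        live_deployments: List[str] = field(default_factory=list)

        def is_deleted(self) -> bool:
            # Deleted only once deletion was requested AND nothing is still live.
            return self.deleting and len(self.live_deployments) == 0

    @dataclass
    class AppStateManager:
        apps: Dict[str, AppState] = field(default_factory=dict)

        def is_ready_for_shutdown(self) -> bool:
            # all() over an empty collection is True, so a manager with no
            # applications left is trivially ready to shut down.
            return all(app.is_deleted() for app in self.apps.values())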
" - "You can ignore this warning if you are shutting " - "down the Ray cluster." - ) - break - except ValueError: # actor name is removed - break - - self._shutdown = True - - def _wait_for_deployments_shutdown(self, timeout_s: int = 60): - """Waits for all deployments to be shut down and deleted. - - Raises TimeoutError if this doesn't happen before timeout_s. - """ - start = time.time() - while time.time() - start < timeout_s: - deployment_statuses = self.get_all_deployment_statuses() - if len(deployment_statuses) == 0: - break - else: - logger.debug( - f"Waiting for shutdown, {len(deployment_statuses)} " - "deployments still alive." + try: + ray.get(self._controller.graceful_shutdown.remote(), timeout=timeout_s) + except ray.exceptions.RayActorError: + # Controller has been shut down. + pass + except TimeoutError: + logger.warning( + f"Controller failed to shut down within {timeout_s}s. " + "Check controller logs for more details." ) - time.sleep(CLIENT_POLLING_INTERVAL_S) - else: - live_names = [ - deployment_status.name for deployment_status in deployment_statuses - ] - raise TimeoutError( - f"Shutdown didn't complete after {timeout_s}s. " - f"Deployments still alive: {live_names}." - ) + self._shutdown = True def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1): """Waits for the named deployment to enter "HEALTHY" status. diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 603c30957f611..cc030a52df69c 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1279,7 +1279,10 @@ def _set_target_state_deleting(self) -> None: self._curr_status_info = DeploymentStatusInfo( self._name, DeploymentStatus.UPDATING ) - logger.info(f"Deleting deployment {self._name}.") + logger.info( + f"Deleting deployment {self._name}.", + extra={"log_to_stderr": False}, + ) def _set_target_state(self, target_info: DeploymentInfo) -> None: """Set the target state for the deployment to the provided info.""" @@ -2284,6 +2287,16 @@ def shutdown(self): # TODO(jiaodong): Need to add some logic to prevent new replicas # from being created once shutdown signal is sent. + def is_ready_for_shutdown(self) -> bool: + """Return whether all deployments are shutdown. + + Check there are no deployment states and no checkpoints. + """ + return ( + len(self._deployment_states) == 0 + and self._kv_store.get(CHECKPOINT_KEY) is None + ) + def _save_checkpoint_func( self, *, writeahead_checkpoints: Optional[Dict[str, Tuple]] ) -> None: diff --git a/python/ray/serve/_private/endpoint_state.py b/python/ray/serve/_private/endpoint_state.py index e6662892173f1..949e2ad08537f 100644 --- a/python/ray/serve/_private/endpoint_state.py +++ b/python/ray/serve/_private/endpoint_state.py @@ -34,6 +34,14 @@ def __init__(self, kv_store: KVStoreBase, long_poll_host: LongPollHost): def shutdown(self): self._kv_store.delete(CHECKPOINT_KEY) + def is_ready_for_shutdown(self) -> bool: + """Returns whether the endpoint checkpoint has been deleted. + + Get the endpoint checkpoint from the kv store. If it is None, then it has been + deleted. 
+ """ + return self._kv_store.get(CHECKPOINT_KEY) is None + def _checkpoint(self): self._kv_store.put(CHECKPOINT_KEY, cloudpickle.dumps(self._endpoints)) diff --git a/python/ray/serve/_private/http_state.py b/python/ray/serve/_private/http_state.py index 45cd98003e149..71a4ea97e8ae0 100644 --- a/python/ray/serve/_private/http_state.py +++ b/python/ray/serve/_private/http_state.py @@ -230,6 +230,27 @@ def shutdown(self): self._shutting_down = True ray.kill(self.actor_handle, no_restart=True) + def is_ready_for_shutdown(self) -> bool: + """Return whether the HTTP proxy actor is shutdown. + + For an HTTP proxy actor to be considered shutdown, it must be marked as + _shutting_down and the actor must be dead. If the actor is dead, the health + check will return RayActorError. + """ + if not self._shutting_down: + return False + + try: + ray.get(self._actor_handle.check_health.remote(), timeout=0.001) + except ray.exceptions.RayActorError: + # The actor is dead, so it's ready for shutdown. + return True + except ray.exceptions.GetTimeoutError: + # The actor is still alive, so it's not ready for shutdown. + return False + + return False + class HTTPState: """Manages all state for HTTP proxies in the system. @@ -269,6 +290,17 @@ def shutdown(self) -> None: for proxy_state in self._proxy_states.values(): proxy_state.shutdown() + def is_ready_for_shutdown(self) -> bool: + """Return whether all proxies are shutdown. + + Iterate through all proxy states and check if all their proxy actors + are shutdown. + """ + return all( + proxy_state.is_ready_for_shutdown() + for proxy_state in self._proxy_states.values() + ) + def get_config(self): return self._config diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 1d1b224470c8c..7b718d0802113 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -187,6 +187,9 @@ async def __init__( worker_id=ray.get_runtime_context().get_worker_id(), log_file_path=get_component_logger_file_path(), ) + self._shutting_down = False + self._shutdown = asyncio.Event() + self._shutdown_start_time = None run_background_task(self.run_control_loop()) @@ -307,6 +310,12 @@ async def run_control_loop(self) -> None: recovering_timeout = RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S start_time = time.time() while True: + if self._shutting_down: + try: + self.shutdown() + except Exception: + logger.exception("Exception during shutdown.") + if ( not self.done_recovering_event.is_set() and time.time() - start_time > recovering_timeout @@ -439,8 +448,29 @@ def get_root_url(self): ) return http_config.root_url + def config_checkpoint_deleted(self) -> bool: + """Returns whether the config checkpoint has been deleted. + + Get the config checkpoint from the kv store. If it is None, then it has been + deleted. + """ + return self.kv_store.get(CONFIG_CHECKPOINT_KEY) is None + def shutdown(self): - """Shuts down the serve instance completely.""" + """Shuts down the serve instance completely. + + This method will only be triggered when `self._shutting_down` is true. It + deletes the kv store for config checkpoints, sets application state to deleting, + delete all deployments, and shuts down all HTTP proxies. Once all these + resources are released, it then kills the controller actor. 
+ """ + if not self._shutting_down: + return + + if self._shutdown_start_time is None: + self._shutdown_start_time = time.time() + + logger.info("Controller shutdown started!", extra={"log_to_stderr": False}) self.kv_store.delete(CONFIG_CHECKPOINT_KEY) self.application_state_manager.shutdown() self.deployment_state_manager.shutdown() @@ -448,6 +478,54 @@ def shutdown(self): if self.http_state: self.http_state.shutdown() + config_checkpoint_deleted = self.config_checkpoint_deleted() + application_is_shutdown = self.application_state_manager.is_ready_for_shutdown() + deployment_is_shutdown = self.deployment_state_manager.is_ready_for_shutdown() + endpoint_is_shutdown = self.endpoint_state.is_ready_for_shutdown() + http_state_is_shutdown = ( + self.http_state is None or self.http_state.is_ready_for_shutdown() + ) + if ( + config_checkpoint_deleted + and application_is_shutdown + and deployment_is_shutdown + and endpoint_is_shutdown + and http_state_is_shutdown + ): + logger.warning( + "All resources have shut down, shutting down controller!", + extra={"log_to_stderr": False}, + ) + _controller_actor = ray.get_runtime_context().current_actor + self._shutdown.set() + ray.kill(_controller_actor, no_restart=True) + elif time.time() - self._shutdown_start_time > 10: + if not config_checkpoint_deleted: + logger.warning( + f"{CONFIG_CHECKPOINT_KEY} not yet deleted", + extra={"log_to_stderr": False}, + ) + if not application_is_shutdown: + logger.warning( + "application not yet shutdown", + extra={"log_to_stderr": False}, + ) + if not deployment_is_shutdown: + logger.warning( + "deployment not yet shutdown", + extra={"log_to_stderr": False}, + ) + if not endpoint_is_shutdown: + logger.warning( + "endpoint not yet shutdown", + extra={"log_to_stderr": False}, + ) + if not http_state_is_shutdown: + logger.warning( + "http_state not yet shutdown", + extra={"log_to_stderr": False}, + ) + def deploy( self, name: str, @@ -835,6 +913,20 @@ def record_multiplexed_replica_info(self, info: MultiplexedReplicaInfo): """ self.deployment_state_manager.record_multiplexed_replica_info(info) + async def graceful_shutdown(self, wait: bool = True): + """Set the shutting down flag on controller to signal shutdown in + run_control_loop(). + + This is used to signal to the controller that it should proceed with shutdown + process, so it can shut down gracefully. It also waits until the shutdown + event is triggered if wait is true. 
+ """ + self._shutting_down = True + if not wait: + return + + await self._shutdown.wait() + @ray.remote(num_cpus=0, max_calls=1) def deploy_serve_application( @@ -943,7 +1035,7 @@ def __init__( http_proxy_port: int = 8000, ): try: - self._controller = ray.get_actor(controller_name, namespace="serve") + self._controller = ray.get_actor(controller_name, namespace=SERVE_NAMESPACE) except ValueError: self._controller = None if self._controller is None: @@ -956,7 +1048,7 @@ def __init__( max_restarts=-1, max_task_retries=-1, resources={HEAD_NODE_RESOURCE_NAME: 0.001}, - namespace="serve", + namespace=SERVE_NAMESPACE, max_concurrency=CONTROLLER_MAX_CONCURRENCY, ).remote( controller_name, diff --git a/python/ray/serve/tests/test_application_state.py b/python/ray/serve/tests/test_application_state.py index 0f7afc16098cb..8bdb18e414615 100644 --- a/python/ray/serve/tests/test_application_state.py +++ b/python/ray/serve/tests/test_application_state.py @@ -658,5 +658,51 @@ def test_recover_during_update(mocked_application_state_manager): assert app_state.status == ApplicationStatus.RUNNING +def test_is_ready_for_shutdown(mocked_application_state_manager): + """Test `is_ready_for_shutdown()` returns the correct state. + + When shutting down applications before deployments are deleted, application state + `is_deleted()` should return False and `is_ready_for_shutdown()` should return + False. When shutting down applications after deployments are deleted, application + state `is_deleted()` should return True and `is_ready_for_shutdown()` should return + True. + """ + ( + app_state_manager, + deployment_state_manager, + kv_store, + ) = mocked_application_state_manager + app_name = "test_app" + deployment_name = "d1" + + # DEPLOY application with deployment "d1" + params = deployment_params(deployment_name) + app_state_manager.apply_deployment_args(app_name, [params]) + app_state = app_state_manager._application_states[app_name] + assert app_state.status == ApplicationStatus.DEPLOYING + + # Once deployment is healthy, app should be running + app_state_manager.update() + assert deployment_state_manager.get_deployment(deployment_name) + deployment_state_manager.set_deployment_healthy(deployment_name) + app_state_manager.update() + assert app_state.status == ApplicationStatus.RUNNING + + # When shutting down applications before deployments are deleted, application state + # `is_deleted()` should return False and `is_ready_for_shutdown()` should return + # False + app_state_manager.shutdown() + assert not app_state.is_deleted() + assert not app_state_manager.is_ready_for_shutdown() + + # When shutting down applications after deployments are deleted, application state + # `is_deleted()` should return True and `is_ready_for_shutdown()` should return True + deployment_state_manager.delete_deployment(deployment_name) + deployment_state_manager.set_deployment_deleted(deployment_name) + app_state_manager.update() + assert app_state.is_deleted() + assert app_state_manager.is_ready_for_shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_cli.py b/python/ray/serve/tests/test_cli.py index 21a3a633592e7..1526588d367f1 100644 --- a/python/ray/serve/tests/test_cli.py +++ b/python/ray/serve/tests/test_cli.py @@ -828,8 +828,9 @@ def parrot(request): parrot_node = parrot.bind() +@pytest.mark.parametrize("number_of_kill_signals", (1, 2)) @pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.") -def 
diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py
index 1d1b224470c8c..7b718d0802113 100644
--- a/python/ray/serve/controller.py
+++ b/python/ray/serve/controller.py
@@ -187,6 +187,9 @@ async def __init__(
             worker_id=ray.get_runtime_context().get_worker_id(),
             log_file_path=get_component_logger_file_path(),
         )
+        self._shutting_down = False
+        self._shutdown = asyncio.Event()
+        self._shutdown_start_time = None
 
         run_background_task(self.run_control_loop())
 
@@ -307,6 +310,12 @@ async def run_control_loop(self) -> None:
         recovering_timeout = RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S
         start_time = time.time()
         while True:
+            if self._shutting_down:
+                try:
+                    self.shutdown()
+                except Exception:
+                    logger.exception("Exception during shutdown.")
+
             if (
                 not self.done_recovering_event.is_set()
                 and time.time() - start_time > recovering_timeout
@@ -439,8 +448,29 @@ def get_root_url(self):
                 )
         return http_config.root_url
 
+    def config_checkpoint_deleted(self) -> bool:
+        """Returns whether the config checkpoint has been deleted.
+
+        Gets the config checkpoint from the kv store. If it is None, then it
+        has been deleted.
+        """
+        return self.kv_store.get(CONFIG_CHECKPOINT_KEY) is None
+
     def shutdown(self):
-        """Shuts down the serve instance completely."""
+        """Shuts down the serve instance completely.
+
+        This method will only be triggered when `self._shutting_down` is true. It
+        deletes the config checkpoint from the kv store, sets the application state
+        to deleting, deletes all deployments, and shuts down all HTTP proxies. Once
+        all these resources are released, it then kills the controller actor.
+        """
+        if not self._shutting_down:
+            return
+
+        if self._shutdown_start_time is None:
+            self._shutdown_start_time = time.time()
+
+        logger.info("Controller shutdown started!", extra={"log_to_stderr": False})
         self.kv_store.delete(CONFIG_CHECKPOINT_KEY)
         self.application_state_manager.shutdown()
         self.deployment_state_manager.shutdown()
@@ -448,6 +478,54 @@ def shutdown(self):
         if self.http_state:
             self.http_state.shutdown()
 
+        config_checkpoint_deleted = self.config_checkpoint_deleted()
+        application_is_shutdown = self.application_state_manager.is_ready_for_shutdown()
+        deployment_is_shutdown = self.deployment_state_manager.is_ready_for_shutdown()
+        endpoint_is_shutdown = self.endpoint_state.is_ready_for_shutdown()
+        http_state_is_shutdown = (
+            self.http_state is None or self.http_state.is_ready_for_shutdown()
+        )
+        if (
+            config_checkpoint_deleted
+            and application_is_shutdown
+            and deployment_is_shutdown
+            and endpoint_is_shutdown
+            and http_state_is_shutdown
+        ):
+            logger.warning(
+                "All resources have shut down, shutting down controller!",
+                extra={"log_to_stderr": False},
+            )
+            _controller_actor = ray.get_runtime_context().current_actor
+            self._shutdown.set()
+            ray.kill(_controller_actor, no_restart=True)
+        elif time.time() - self._shutdown_start_time > 10:
+            if not config_checkpoint_deleted:
+                logger.warning(
+                    f"{CONFIG_CHECKPOINT_KEY} not yet deleted",
+                    extra={"log_to_stderr": False},
+                )
+            if not application_is_shutdown:
+                logger.warning(
+                    "application not yet shutdown",
+                    extra={"log_to_stderr": False},
+                )
+            if not deployment_is_shutdown:
+                logger.warning(
+                    "deployment not yet shutdown",
+                    extra={"log_to_stderr": False},
+                )
+            if not endpoint_is_shutdown:
+                logger.warning(
+                    "endpoint not yet shutdown",
+                    extra={"log_to_stderr": False},
+                )
+            if not http_state_is_shutdown:
+                logger.warning(
+                    "http_state not yet shutdown",
+                    extra={"log_to_stderr": False},
+                )
+
     def deploy(
         self,
         name: str,
@@ -835,6 +913,20 @@ def record_multiplexed_replica_info(self, info: MultiplexedReplicaInfo):
         """
         self.deployment_state_manager.record_multiplexed_replica_info(info)
 
+    async def graceful_shutdown(self, wait: bool = True):
+        """Set the shutting-down flag on the controller to signal shutdown in
+        run_control_loop().
+
+        This is used to signal to the controller that it should proceed with the
+        shutdown process, so it can shut down gracefully. If `wait` is True, it
+        also waits until the shutdown event is triggered.
+        """
+        self._shutting_down = True
+        if not wait:
+            return
+
+        await self._shutdown.wait()
+
 
 @ray.remote(num_cpus=0, max_calls=1)
 def deploy_serve_application(
@@ -943,7 +1035,7 @@ def __init__(
         http_proxy_port: int = 8000,
     ):
         try:
-            self._controller = ray.get_actor(controller_name, namespace="serve")
+            self._controller = ray.get_actor(controller_name, namespace=SERVE_NAMESPACE)
         except ValueError:
             self._controller = None
         if self._controller is None:
@@ -956,7 +1048,7 @@ def __init__(
                 max_restarts=-1,
                 max_task_retries=-1,
                 resources={HEAD_NODE_RESOURCE_NAME: 0.001},
-                namespace="serve",
+                namespace=SERVE_NAMESPACE,
                 max_concurrency=CONTROLLER_MAX_CONCURRENCY,
             ).remote(
                 controller_name,
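The controller-side handshake is flag-plus-event: `graceful_shutdown()` only flips `_shutting_down` and optionally awaits `_shutdown`, while the control loop re-runs the idempotent `shutdown()` on every iteration until all five readiness checks pass. A condensed, self-contained sketch of that structure, with the readiness checks stubbed out:

    import asyncio

    class MiniController:
        def __init__(self):
            self._shutting_down = False
            self._shutdown = asyncio.Event()

        async def run_control_loop(self):
            while not self._shutdown.is_set():
                if self._shutting_down:
                    self.shutdown()  # Idempotent; retried every iteration.
                await asyncio.sleep(0.1)

        def shutdown(self):
            if not self._shutting_down:
                return
            if self._all_resources_released():
                # Stands in for ray.kill() on the real controller actor.
                self._shutdown.set()

        def _all_resources_released(self) -> bool:
            return True  # Placeholder for the five readiness checks above.

        async def graceful_shutdown(self, wait: bool = True):
            self._shutting_down = True
            if wait:
                await self._shutdown.wait()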
+ """ + # Setup endpoint state with checkpoint + endpoint_state = mock_endpoint_state + endpoint_state._checkpoint() + + # Before shutdown is called, `is_ready_for_shutdown()` should return False + assert not endpoint_state.is_ready_for_shutdown() + + endpoint_state.shutdown() + + # After shutdown is called, `is_ready_for_shutdown()` should return True + assert endpoint_state.is_ready_for_shutdown() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_http_state.py b/python/ray/serve/tests/test_http_state.py index 3a07af1c97534..dbe56ec263c4a 100644 --- a/python/ray/serve/tests/test_http_state.py +++ b/python/ray/serve/tests/test_http_state.py @@ -157,9 +157,7 @@ def test_node_selection(all_nodes, mock_get_all_node_ids): assert set(another_seed) != set(selected_nodes) -def test_http_state_update_restarts_unhealthy_proxies( - mock_get_all_node_ids, setup_controller -): +def test_http_state_update_restarts_unhealthy_proxies(mock_get_all_node_ids): """Test the update method in HTTPState would kill and restart unhealthy proxies. Set up a HTTPProxyState with UNHEALTHY status. Calls the update method on the @@ -199,7 +197,7 @@ def _update_state_and_check_proxy_status( assert new_proxy != old_proxy -def test_http_proxy_state_update_shutting_down(setup_controller): +def test_http_proxy_state_update_shutting_down(): """Test calling update method on HTTPProxyState when the proxy state is shutting down. @@ -218,7 +216,7 @@ def test_http_proxy_state_update_shutting_down(setup_controller): assert previous_status == current_status -def test_http_proxy_state_update_starting_ready_succeed(setup_controller): +def test_http_proxy_state_update_starting_ready_succeed(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call succeeded. @@ -238,7 +236,7 @@ def test_http_proxy_state_update_starting_ready_succeed(setup_controller): ) -def test_http_proxy_state_update_starting_ready_failed_once(setup_controller): +def test_http_proxy_state_update_starting_ready_failed_once(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call failed once and succeeded for the following call. @@ -279,7 +277,7 @@ async def check_health(self): ) -def test_http_proxy_state_update_starting_ready_always_fails(setup_controller): +def test_http_proxy_state_update_starting_ready_always_fails(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call is always failing. @@ -313,7 +311,7 @@ async def check_health(self): @patch("ray.serve._private.http_state.PROXY_READY_CHECK_TIMEOUT_S", 1) -def test_http_proxy_state_update_starting_ready_always_timeout(setup_controller): +def test_http_proxy_state_update_starting_ready_always_timeout(): """Test calling update method on HTTPProxyState when the proxy state is STARTING and when the ready call always timed out. 
diff --git a/python/ray/serve/tests/test_deployment_state.py b/python/ray/serve/tests/test_deployment_state.py
index 97bddee4577c4..ef3dd6ebdb237 100644
--- a/python/ray/serve/tests/test_deployment_state.py
+++ b/python/ray/serve/tests/test_deployment_state.py
@@ -2525,6 +2525,9 @@ def test_shutdown(mock_deployment_state_manager, is_driver_deployment):
     # Test shutdown flow
     assert not deployment_state._replicas.get()[0]._actor.stopped
 
+    # Before shutdown, `is_ready_for_shutdown()` should return False
+    assert not deployment_state_manager.is_ready_for_shutdown()
+
     deployment_state_manager.shutdown()
 
     timer.advance(grace_period_s + 0.1)
@@ -2541,6 +2544,9 @@ def test_shutdown(mock_deployment_state_manager, is_driver_deployment):
     check_counts(deployment_state, total=0)
     assert len(deployment_state_manager.get_deployment_statuses()) == 0
 
+    # After all deployments shut down, `is_ready_for_shutdown()` should return True
+    assert deployment_state_manager.is_ready_for_shutdown()
+
 
 def test_resource_requirements_none():
     """Ensure resource_requirements doesn't break if a requirement is None"""
diff --git a/python/ray/serve/tests/test_endpoint_state.py b/python/ray/serve/tests/test_endpoint_state.py
new file mode 100644
index 0000000000000..23c68011a3f12
--- /dev/null
+++ b/python/ray/serve/tests/test_endpoint_state.py
@@ -0,0 +1,65 @@
+import sys
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+from ray.serve._private.endpoint_state import EndpointState
+
+
+class MockKVStore:
+    def __init__(self):
+        self.store = dict()
+
+    def put(self, key: str, val: Any) -> bool:
+        if not isinstance(key, str):
+            raise TypeError("key must be a string, got: {}.".format(type(key)))
+        self.store[key] = val
+        return True
+
+    def get(self, key: str) -> Any:
+        if not isinstance(key, str):
+            raise TypeError("key must be a string, got: {}.".format(type(key)))
+        return self.store.get(key, None)
+
+    def delete(self, key: str) -> bool:
+        if not isinstance(key, str):
+            raise TypeError("key must be a string, got: {}.".format(type(key)))
+
+        if key in self.store:
+            del self.store[key]
+            return True
+
+        return False
+
+
+@pytest.fixture
+def mock_endpoint_state() -> EndpointState:
+    with patch("ray.serve._private.long_poll.LongPollHost") as mock_long_poll:
+        endpoint_state = EndpointState(
+            kv_store=MockKVStore(),
+            long_poll_host=mock_long_poll,
+        )
+        yield endpoint_state
+
+
+def test_is_ready_for_shutdown(mock_endpoint_state):
+    """Test `is_ready_for_shutdown()` returns the correct state.
+
+    Before the endpoint state is shut down, `is_ready_for_shutdown()` should
+    return False. After it is shut down, `is_ready_for_shutdown()` should
+    return True.
+    """
+    # Set up endpoint state with a checkpoint
+    endpoint_state = mock_endpoint_state
+    endpoint_state._checkpoint()
+
+    # Before shutdown is called, `is_ready_for_shutdown()` should return False
+    assert not endpoint_state.is_ready_for_shutdown()
+
+    endpoint_state.shutdown()
+
+    # After shutdown is called, `is_ready_for_shutdown()` should return True
+    assert endpoint_state.is_ready_for_shutdown()
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-v", "-s", __file__]))
+ """ + state = _make_http_state(HTTPOptions(location=DeploymentMode.EveryNode)) + + for node_id, node_ip_address in all_nodes: + state._proxy_states[node_id] = _create_http_proxy_state( + proxy_actor_class=HTTPProxyActor, + status=HTTPProxyStatus.HEALTHY, + node_id=node_id, + host="localhost", + port=8000, + root_path="/", + controller_name=SERVE_CONTROLLER_NAME, + node_ip_address=node_ip_address, + ) + + # Ensure before shutdown, state is not shutdown + assert not state.is_ready_for_shutdown() + + state.shutdown() + + # Ensure after shutdown, state is shutdown and all proxy states are shutdown + def check_is_ready_for_shutdown(): + return state.is_ready_for_shutdown() + + wait_for_condition(check_is_ready_for_shutdown) + + if __name__ == "__main__": import sys diff --git a/python/ray/serve/tests/test_standalone3.py b/python/ray/serve/tests/test_standalone3.py index 0fbfdf457ea40..8ee0e5c7e8510 100644 --- a/python/ray/serve/tests/test_standalone3.py +++ b/python/ray/serve/tests/test_standalone3.py @@ -506,5 +506,105 @@ def _check(): serve.shutdown() +@pytest.mark.parametrize("wait_for_controller_shutdown", (True, False)) +def test_controller_shutdown_gracefully( + shutdown_ray, call_ray_stop_only, wait_for_controller_shutdown # noqa: F811 +): + """Test controller shutdown gracefully when calling `graceful_shutdown()`. + + Called `graceful_shutdown()` on the controller, so it will start shutdown and + eventually all actors will be in DEAD state. Test both cases whether to wait for + the controller shutdown or not should both resolve graceful shutdown. + """ + # Setup a cluster with 2 nodes + cluster = Cluster() + cluster.add_node() + cluster.add_node() + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + # Deploy 2 replicas + @serve.deployment(num_replicas=2) + class HelloModel: + def __call__(self): + return "hello" + + model = HelloModel.bind() + serve.run(target=model) + + # Ensure total actors of 2 proxies, 1 controller, and 2 replicas + wait_for_condition(lambda: len(ray._private.state.actors()) == 5) + assert len(ray.nodes()) == 2 + + # Call `graceful_shutdown()` on the controller, so it will start shutdown. + client = get_global_client() + if wait_for_controller_shutdown: + # Waiting for controller shutdown will throw RayActorError when the controller + # killed itself. + with pytest.raises(ray.exceptions.RayActorError): + ray.get(client._controller.graceful_shutdown.remote(True)) + else: + ray.get(client._controller.graceful_shutdown.remote(False)) + + # Ensure the all resources are shutdown. + wait_for_condition( + lambda: all( + [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + ) + ) + + # Clean up serve. + serve.shutdown() + + +def test_client_shutdown_gracefully_when_timeout( + shutdown_ray, call_ray_stop_only, caplog # noqa: F811 +): + """Test client shutdown gracefully when timeout. + + When the controller is taking longer than the timeout to shutdown, the client will + log timeout message and exit the process. The controller will continue to shutdown + everything gracefully. 
+ """ + # Setup a cluster with 2 nodes + cluster = Cluster() + cluster.add_node() + cluster.add_node() + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + # Deploy 2 replicas + @serve.deployment(num_replicas=2) + class HelloModel: + def __call__(self): + return "hello" + + model = HelloModel.bind() + serve.run(target=model) + + # Ensure total actors of 2 proxies, 1 controller, and 2 replicas + wait_for_condition(lambda: len(ray._private.state.actors()) == 5) + assert len(ray.nodes()) == 2 + + # Ensure client times out if the controller does not shutdown within timeout. + timeout_s = 0.0 + client = get_global_client() + client.shutdown(timeout_s=timeout_s) + assert ( + f"Controller failed to shut down within {timeout_s}s. " + f"Check controller logs for more details." in caplog.text + ) + + # Ensure the all resources are shutdown gracefully. + wait_for_condition( + lambda: all( + [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + ), + ) + + # Clean up serve. + serve.shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__]))