From 97f4964dce1ac737f59db39bc4fc5efe5c38940b Mon Sep 17 00:00:00 2001 From: shrekris-anyscale <92341594+shrekris-anyscale@users.noreply.github.com> Date: Wed, 2 Aug 2023 09:15:19 -0700 Subject: [PATCH] [Serve] Decrement `ray_serve_deployment_queued_queries` when client disconnects (#37965) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `ray_serve_deployment_queued_queries` metric tracks the number of queries that have yet to be assigned a replica. If a client disconnects before its query has been assigned a replica– but after the metric has counted their query– the query terminates, but the metric doesn't decrease. This change decrements `ray_serve_deployment_queued_queries` when a queued request is disconnected. Signed-off-by: Edward Oakes --- python/ray/serve/_private/router.py | 37 +++++++++------- python/ray/serve/tests/test_metrics.py | 61 +++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 03c01fc988f49..b24fddcc412a6 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -929,20 +929,25 @@ async def assign_request( }, ) - query = Query( - args=list(request_args), - kwargs=request_kwargs, - metadata=request_meta, - ) - await query.resolve_async_tasks() - result = await self._replica_scheduler.assign_replica(query) - - self.num_queued_queries -= 1 - self.num_queued_queries_gauge.set( - self.num_queued_queries, - tags={ - "application": request_meta.app_name, - }, - ) + try: + query = Query( + args=list(request_args), + kwargs=request_kwargs, + metadata=request_meta, + ) + await query.resolve_async_tasks() + result = await self._replica_scheduler.assign_replica(query) - return result + return result + finally: + # If the query is disconnected before assignment, this coroutine + # gets cancelled by the caller and an asyncio.CancelledError is + # raised. The finally block ensures that num_queued_queries + # is correctly decremented in this case. + self.num_queued_queries -= 1 + self.num_queued_queries_gauge.set( + self.num_queued_queries, + tags={ + "application": request_meta.app_name, + }, + ) diff --git a/python/ray/serve/tests/test_metrics.py b/python/ray/serve/tests/test_metrics.py index b72b0b0318306..0f2e39abab47f 100644 --- a/python/ray/serve/tests/test_metrics.py +++ b/python/ray/serve/tests/test_metrics.py @@ -1,4 +1,6 @@ import os +from functools import partial +from multiprocessing import Pool from typing import List, Dict, DefaultDict import requests @@ -841,6 +843,63 @@ def verify_metrics(): ) +def test_queued_queries_disconnected(serve_start_shutdown): + """Check that queued_queries decrements when queued requests disconnect.""" + + signal = SignalActor.remote() + + @serve.deployment( + max_concurrent_queries=1, + graceful_shutdown_timeout_s=0.0001, + ) + async def hang_on_first_request(): + await signal.wait.remote() + + serve.run(hang_on_first_request.bind()) + + print("Deployed hang_on_first_request deployment.") + + def queue_size() -> float: + metrics = requests.get("http://127.0.0.1:9999").text + queue_size = -1 + for line in metrics.split("\n"): + if "ray_serve_deployment_queued_queries" in line: + queue_size = line.split(" ")[-1] + + return float(queue_size) + + def first_request_executing(request_future) -> bool: + try: + request_future.get(timeout=0.1) + except Exception: + return ray.get(signal.cur_num_waiters.remote()) == 1 + + url = "http://localhost:8000/" + + pool = Pool() + + # Make a request to block the deployment from accepting other requests + fut = pool.apply_async(partial(requests.get, url)) + wait_for_condition(lambda: first_request_executing(fut), timeout=5) + print("Executed first request.") + + num_requests = 5 + for _ in range(num_requests): + pool.apply_async(partial(requests.get, url)) + print(f"Executed {num_requests} more requests.") + + # First request should be processing. All others should be queued. + wait_for_condition(lambda: queue_size() == num_requests, timeout=15) + print("ray_serve_deployment_queued_queries updated successfully.") + + # Disconnect all requests by terminating the process pool. + pool.terminate() + print("Terminated all requests.") + + wait_for_condition(lambda: queue_size() == 0, timeout=15) + print("ray_serve_deployment_queued_queries updated successfully.") + + def test_actor_summary(serve_instance): @serve.deployment def f(): @@ -855,7 +914,7 @@ def f(): def get_metric_dictionaries(name: str, timeout: float = 20) -> List[Dict]: - """Gets a list of metric's dictionaries from metrics' text output. + """Gets a list of metric's tags from metrics' text output. Return: Example: