diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index 7b850b6ccf1c1..54f6908072c68 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -1009,18 +1009,17 @@ async def send_request_to_replica( # the trailers message has been sent. if not asgi_message.get("more_trailers", False): response_generator.stop_checking_for_disconnect() - elif asgi_message["type"] == "websocket.disconnect": + elif asgi_message["type"] in [ + "websocket.close", + "websocket.disconnect", + ]: status_code = str(asgi_message["code"]) - - # Check based on standard WebSocket status codes - if status_code in ["1000", "1001"]: - # Normal closure or going away, no error - is_error = False - else: - # Other 1xxx codes are specified as errors - is_error = status_code.startswith("1") - - status = ResponseStatus(code=status_code, is_error=is_error) + status = ResponseStatus( + code=status_code, + # All status codes are considered errors aside from: + # 1000 (CLOSE_NORMAL), 1001 (CLOSE_GOING_AWAY). + is_error=status_code not in ["1000", "1001"], + ) response_generator.stop_checking_for_disconnect() yield asgi_message diff --git a/python/ray/serve/tests/test_metrics.py b/python/ray/serve/tests/test_metrics.py index 6b3e674e79077..6f64666a96ba7 100644 --- a/python/ray/serve/tests/test_metrics.py +++ b/python/ray/serve/tests/test_metrics.py @@ -6,7 +6,11 @@ import grpc import pytest import requests -from fastapi import FastAPI +from fastapi import FastAPI, WebSocket +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from websockets.exceptions import ConnectionClosed +from websockets.sync.client import connect import ray import ray.util.state as state_api @@ -583,6 +587,161 @@ def f(*args): print("serve_grpc_request_latency_ms_sum working as expected.") +def test_proxy_metrics_http_status_code_is_error(serve_start_shutdown): + """Verify that 2xx status codes aren't errors, others are.""" + + def check_request_count_metrics( + expected_error_count: int, + expected_success_count: int, + ): + resp = requests.get("http://127.0.0.1:9999").text + error_count = 0 + success_count = 0 + for line in resp.split("\n"): + if line.startswith("ray_serve_num_http_error_requests_total"): + error_count += int(float(line.split(" ")[-1])) + if line.startswith("ray_serve_num_http_requests_total"): + success_count += int(float(line.split(" ")[-1])) + + assert error_count == expected_error_count + assert success_count == expected_success_count + return True + + @serve.deployment + async def return_status_code(request: Request): + code = int((await request.body()).decode("utf-8")) + return PlainTextResponse("", status_code=code) + + serve.run(return_status_code.bind()) + + # 200 is not an error. + r = requests.get("http://127.0.0.1:8000/", data=b"200") + assert r.status_code == 200 + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=1, + ) + + # 2xx is not an error. + r = requests.get("http://127.0.0.1:8000/", data=b"250") + assert r.status_code == 250 + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=2, + ) + + # 3xx is an error. + r = requests.get("http://127.0.0.1:8000/", data=b"300") + assert r.status_code == 300 + wait_for_condition( + check_request_count_metrics, + expected_error_count=1, + expected_success_count=3, + ) + + # 4xx is an error. + r = requests.get("http://127.0.0.1:8000/", data=b"400") + assert r.status_code == 400 + wait_for_condition( + check_request_count_metrics, + expected_error_count=2, + expected_success_count=4, + ) + + # 5xx is an error. + r = requests.get("http://127.0.0.1:8000/", data=b"500") + assert r.status_code == 500 + wait_for_condition( + check_request_count_metrics, + expected_error_count=3, + expected_success_count=5, + ) + + +def test_proxy_metrics_websocket_status_code_is_error(serve_start_shutdown): + """Verify that status codes aisde from 1000 or 1001 are errors.""" + + def check_request_count_metrics( + expected_error_count: int, + expected_success_count: int, + ): + resp = requests.get("http://127.0.0.1:9999").text + error_count = 0 + success_count = 0 + for line in resp.split("\n"): + if line.startswith("ray_serve_num_http_error_requests_total"): + error_count += int(float(line.split(" ")[-1])) + if line.startswith("ray_serve_num_http_requests_total"): + success_count += int(float(line.split(" ")[-1])) + + assert error_count == expected_error_count + assert success_count == expected_success_count + return True + + fastapi_app = FastAPI() + + @serve.deployment + @serve.ingress(fastapi_app) + class WebSocketServer: + @fastapi_app.websocket("/") + async def accept_then_close(self, ws: WebSocket): + await ws.accept() + code = int(await ws.receive_text()) + await ws.close(code=code) + + serve.run(WebSocketServer.bind()) + + # Regular disconnect (1000) is not an error. + with connect("ws://localhost:8000/") as ws: + with pytest.raises(ConnectionClosed): + ws.send("1000") + ws.recv() + + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=1, + ) + + # Goaway disconnect (1001) is not an error. + with connect("ws://localhost:8000/") as ws: + with pytest.raises(ConnectionClosed): + ws.send("1001") + ws.recv() + + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=2, + ) + + # Other codes are errors. + with connect("ws://localhost:8000/") as ws: + with pytest.raises(ConnectionClosed): + ws.send("1011") + ws.recv() + + wait_for_condition( + check_request_count_metrics, + expected_error_count=1, + expected_success_count=3, + ) + + # Other codes are errors. + with connect("ws://localhost:8000/") as ws: + with pytest.raises(ConnectionClosed): + ws.send("3000") + ws.recv() + + wait_for_condition( + check_request_count_metrics, + expected_error_count=2, + expected_success_count=4, + ) + + def test_replica_metrics_fields(serve_start_shutdown): """Test replica metrics fields"""