[serve] Add experimental support for StreamingResponse using `RayObjectRefGenerator` (#35720)

Adds experimental support for using `StreamingResponse`s to stream intermediate results back to the client. This is currently gated behind a feature flag (you must set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1`).

This is implemented using the Ray `StreamingObjectRefGenerator` interface. When the feature flag is on, the HTTP proxy uses `num_returns="streaming"` for _all_ calls to downstream replicas. The replica code has been modified to incrementally yield raw ASGI messages back to the HTTP proxy.
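As a rough illustration of what "raw ASGI messages" means for a streamed response, the sketch below yields the message sequence defined by the ASGI HTTP spec: one `http.response.start`, then body messages with `more_body=True`, then a final empty body. It is not the replica code from this commit; `stream_numbers_as_asgi` and `collect` are hypothetical names, and no Ray or pickling is involved.

```python
import asyncio
from typing import Any, AsyncIterator, Dict, List

ASGIMessage = Dict[str, Any]


async def stream_numbers_as_asgi(max_value: int) -> AsyncIterator[ASGIMessage]:
    """Yield the ASGI messages of a streamed text/plain response, one at a time.

    Hypothetical helper for illustration only.
    """
    # Exactly one start message carries the status code and headers.
    yield {
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    }
    # Each number is its own body message; more_body=True keeps the stream open.
    for i in range(max_value):
        yield {
            "type": "http.response.body",
            "body": str(i).encode(),
            "more_body": True,
        }
        await asyncio.sleep(0)  # give other tasks a chance to run
    # An empty body with more_body=False marks the end of the response.
    yield {"type": "http.response.body", "body": b"", "more_body": False}


async def collect(max_value: int) -> List[ASGIMessage]:
    """Drain the generator into a list so the sequence can be inspected."""
    return [message async for message in stream_numbers_as_asgi(max_value)]


messages = asyncio.run(collect(3))
```

Because each message is produced independently, a consumer can forward the start message and early body chunks before the generator has finished.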

Known limitations & follow-ups (to be addressed before a non-experimental release):

- Minor performance regression due to an extra RPC in the streaming protocol (see the microbenchmark results posted on #35468). Most of this overhead should be optimized away before streaming is turned on by default.
- Streaming is not yet possible using the `ServeHandle` interface: #35777
- `max_concurrent_queries` is not respected by the HTTP proxy when streaming is enabled; we do simple round-robin instead: #35778
- The timeout set in the HTTP proxy does not apply to streaming responses: #35779
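To make the `max_concurrent_queries` limitation concrete, here is a minimal sketch of plain round-robin assignment, assuming a fixed list of replicas. It is not the router code from this commit; `assign_round_robin` and the replica names are hypothetical. The point is that the next replica is chosen purely by position, without looking at how many requests each replica already has in flight.

```python
import itertools
from typing import Dict, List


def assign_round_robin(replicas: List[str], num_requests: int) -> Dict[str, int]:
    """Count how many requests each replica receives under round-robin.

    Hypothetical helper: no per-replica concurrency limit is consulted.
    """
    counts = {replica: 0 for replica in replicas}
    next_replica = itertools.cycle(replicas)
    for _ in range(num_requests):
        counts[next(next_replica)] += 1
    return counts


counts = assign_round_robin(["replica-a", "replica-b", "replica-c"], 7)
```

A router that respected a concurrency limit would also need to track in-flight requests per replica and skip replicas that are at capacity.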
edoakes authored May 26, 2023
1 parent 08fcf79 commit 6ae6920
Showing 19 changed files with 860 additions and 94 deletions.
13 changes: 12 additions & 1 deletion doc/BUILD
@@ -131,12 +131,22 @@ py_test_run_all_subdirectory(
exclude = [
"source/serve/doc_code/distilbert.py",
"source/serve/doc_code/stable_diffusion.py",
"source/serve/doc_code/object_detection.py",
"source/serve/doc_code/object_detection.py",
"source/serve/doc_code/streaming_example.py",
],
extra_srcs = [],
tags = ["exclusive", "team:serve"],
)

py_test_run_all_subdirectory(
size = "medium",
include = ["source/serve/doc_code/streaming_example.py"],
exclude = [],
extra_srcs = [],
tags = ["exclusive", "team:serve"],
env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"},
)

py_test_run_all_subdirectory(
size = "medium",
include = [
@@ -151,6 +161,7 @@ py_test_run_all_subdirectory(




# --------------------------------------------------------------------
# Test all doc/source/tune/doc_code code included in rst/md files.
# --------------------------------------------------------------------
40 changes: 40 additions & 0 deletions doc/source/serve/doc_code/streaming_example.py
@@ -0,0 +1,40 @@
# flake8: noqa

# __begin_example__
import time
from typing import Generator

import requests
from starlette.responses import StreamingResponse
from starlette.requests import Request

from ray import serve


@serve.deployment
class StreamingResponder:
    def generate_numbers(self, max: int) -> Generator[str, None, None]:
        for i in range(max):
            yield str(i)
            time.sleep(0.1)

    def __call__(self, request: Request) -> StreamingResponse:
        max = request.query_params.get("max", "25")
        gen = self.generate_numbers(int(max))
        return StreamingResponse(gen, status_code=200, media_type="text/plain")


serve.run(StreamingResponder.bind())

r = requests.get("http://localhost:8000?max=10", stream=True)
start = time.time()
r.raise_for_status()
for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
    print(f"Got result {round(time.time()-start, 1)}s after start: '{chunk}'")
# __end_example__


r = requests.get("http://localhost:8000?max=10", stream=True)
r.raise_for_status()
for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)):
    assert chunk == str(i)
49 changes: 47 additions & 2 deletions doc/source/serve/http-guide.md
@@ -3,7 +3,7 @@
This section helps you understand how to:
- send HTTP requests to Serve deployments
- use Ray Serve to integrate with FastAPI
- use customized HTTP Adapters
- use customized HTTP adapters
- choose which feature to use for your use case

## Choosing the right HTTP feature
@@ -74,6 +74,51 @@ Existing middlewares, **automatic OpenAPI documentation generation**, and other
Serve currently does not support WebSockets. If you have a use case that requires it, please [let us know](https://github.com/ray-project/ray/issues/new/choose)!
```

(serve-http-streaming-response)=
## Streaming Responses

```{warning}
Support for HTTP streaming responses is experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1` on the cluster before starting Ray. If you encounter any issues, [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose).
```

Some applications must stream incremental results back to the caller.
This is common for text generation using large language models (LLMs) or video processing applications.
The full forward pass may take multiple seconds, so returning incremental results as they become available gives a much better user experience.

To use HTTP response streaming, return a [StreamingResponse](https://www.starlette.io/responses/#streamingresponse) that wraps a generator from your HTTP handler.
This is supported for basic HTTP ingress deployments using a `__call__` method and when using the [FastAPI integration](serve-fastapi-http).

The code below defines a Serve application that incrementally streams numbers up to a provided `max`.
The client-side code handles the streamed output by passing the `stream=True` option to the [requests](https://requests.readthedocs.io/en/latest/user/advanced/#streaming-requests) library.

```{literalinclude} ../serve/doc_code/streaming_example.py
:start-after: __begin_example__
:end-before: __end_example__
:language: python
```

Save this code in `stream.py` and run it:

```bash
$ RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 python stream.py
[2023-05-25 10:44:23] INFO ray._private.worker::Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
(ServeController pid=40401) INFO 2023-05-25 10:44:25,296 controller 40401 deployment_state.py:1259 - Deploying new version of deployment default_StreamingResponder.
(HTTPProxyActor pid=40403) INFO: Started server process [40403]
(ServeController pid=40401) INFO 2023-05-25 10:44:25,333 controller 40401 deployment_state.py:1498 - Adding 1 replica to deployment default_StreamingResponder.
Got result 0.0s after start: '0'
Got result 0.1s after start: '1'
Got result 0.2s after start: '2'
Got result 0.3s after start: '3'
Got result 0.4s after start: '4'
Got result 0.5s after start: '5'
Got result 0.6s after start: '6'
Got result 0.7s after start: '7'
Got result 0.8s after start: '8'
Got result 0.9s after start: '9'
(ServeReplica:default_StreamingResponder pid=41052) INFO 2023-05-25 10:49:52,230 default_StreamingResponder default_StreamingResponder#qlZFCa yomKnJifNJ / default replica.py:634 - __CALL__ OK 1017.6ms
```

(serve-http-adapters)=

## HTTP Adapters
@@ -190,7 +235,7 @@ PredictorDeployment.deploy(..., http_adapter=User)
DAGDriver.bind(other_node, http_adapter=User)

```
### List of Built-in Adapters
### List of built-in adapters

Here is a list of adapters; please feel free to [contribute more](https://github.com/ray-project/ray/issues/new/choose)!

34 changes: 33 additions & 1 deletion python/ray/serve/BUILD
@@ -105,6 +105,14 @@ py_test(
deps = [":serve_lib"],
)

py_test(
name = "test_http_util",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)

py_test(
name = "test_advanced",
size = "small",
@@ -448,6 +456,7 @@ py_test(
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)

# Runs test_api and test_failure with injected failures in the controller.
py_test(
name = "test_controller_crashes",
@@ -460,6 +469,29 @@ py_test(
deps = [":serve_lib"],
)

# Runs test_api, test_fastapi, and test_http_adapters with experimental streaming turned on.
py_test(
name = "test_experimental_streaming",
size = "large",
srcs = glob(["tests/test_experimental_streaming.py",
"tests/test_api.py",
"tests/test_fastapi.py",
"tests/test_http_adapters.py",
"**/conftest.py"]),
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"},
)

py_test(
name = "test_streaming_response",
size = "large",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"},
)

py_test(
name = "test_controller_recovery",
size = "medium",
@@ -581,4 +613,4 @@ py_test(
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)
)
7 changes: 7 additions & 0 deletions python/ray/serve/_private/client.py
@@ -440,6 +440,7 @@ def get_handle(
missing_ok: Optional[bool] = False,
sync: bool = True,
_internal_pickled_http_request: bool = False,
_stream: bool = False,
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service deployment to invoke it from Python.
@@ -450,6 +451,10 @@ def get_handle(
sync: If true, then Serve will return a ServeHandle that
works everywhere. Otherwise, Serve will return a ServeHandle
that's only usable in asyncio loop.
_internal_pickled_http_request: Indicates that this handle will be used
to send HTTP requests from the proxy to ingress deployment replicas.
_stream: Indicates that this handle should use
`num_returns="streaming"`.
Returns:
RayServeHandle
@@ -469,12 +474,14 @@ def get_handle(
self._controller,
deployment_name,
_internal_pickled_http_request=_internal_pickled_http_request,
_stream=_stream,
)
else:
handle = RayServeHandle(
self._controller,
deployment_name,
_internal_pickled_http_request=_internal_pickled_http_request,
_stream=_stream,
)

self.handle_cache[cache_key] = handle
6 changes: 6 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -191,3 +191,9 @@ class ServeHandleType(str, Enum):

# Serve HTTP request header key for routing requests.
SERVE_MULTIPLEXED_MODEL_ID = "serve_multiplexed_model_id"

# Feature flag to enable StreamingResponse support.
# When turned on, *all* HTTP responses will use Ray streaming object refs.
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING = (
    os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") == "1"
)
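The flag check above can be exercised in isolation. The sketch below mirrors the same comparison in a hypothetical helper, `streaming_enabled`, that takes the environment as a parameter. It is meant only to show that the exact string `"1"` is the sole value that enables the feature; values such as `"true"` do not.

```python
import os
from typing import Mapping


def streaming_enabled(environ: Mapping[str, str] = os.environ) -> bool:
    """Return True only when the flag is set to the exact string "1".

    Hypothetical helper mirroring the comparison used for the constant.
    """
    return environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") == "1"


enabled = streaming_enabled({"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"})
unset = streaming_enabled({})
truthy_string = streaming_enabled({"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "true"})
```

Since the constant in `constants.py` is computed at module import time, the environment variable has to be set before the Serve processes start; changing it afterwards has no effect on an already-imported module.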
74 changes: 69 additions & 5 deletions python/ray/serve/_private/http_proxy.py
@@ -6,12 +6,13 @@
import pickle
import socket
import time
from typing import Callable, List, Dict, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from ray._private.utils import get_or_create_event_loop

import uvicorn
import starlette.responses
import starlette.routing
from starlette.types import Receive, Scope, Send

import ray
from ray.exceptions import RayActorError, RayTaskError
@@ -29,9 +30,10 @@
from ray.serve._private.common import EndpointInfo, EndpointTag, ApplicationName
from ray.serve._private.constants import (
SERVE_LOGGER_NAME,
SERVE_MULTIPLEXED_MODEL_ID,
SERVE_NAMESPACE,
DEFAULT_LATENCY_BUCKET_MS,
SERVE_MULTIPLEXED_MODEL_ID,
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
from ray.serve._private.logging_utils import (
@@ -72,6 +74,63 @@
)


async def _handle_streaming_response(
    asgi_response_generator: "ray._raylet.StreamingObjectRefGenerator",
    scope: Scope,
    receive: Receive,
    send: Send,
) -> str:
    """Consumes the `asgi_response_generator` and sends its data over `send`.

    This function is a proxy for a downstream ASGI response. The passed
    generator is expected to return a stream of pickled ASGI messages
    (dictionaries) that are sent using the provided ASGI interface.

    Exception handling depends on whether the first message has already been sent:
        - if an exception happens *before* the first message, a 500 status is sent.
        - if an exception happens *after* the first message, the response stream is
          terminated.

    The difference in behavior is because once the first message has been sent, the
    client has already received the status code so we cannot send a `500` (internal
    server error).

    Returns:
        status_code
    """

    status_code = ""
    try:
        async for obj_ref in asgi_response_generator:
            asgi_messages: List[Dict[str, Any]] = pickle.loads(await obj_ref)
            for asgi_message in asgi_messages:
                # There must be exactly one "http.response.start" message that
                # always contains the "status" field.
                if not status_code:
                    assert asgi_message["type"] == "http.response.start", (
                        "First response message must be 'http.response.start'",
                    )
                    assert "status" in asgi_message, (
                        "'http.response.start' message must contain 'status'",
                    )
                    status_code = str(asgi_message["status"])

                await send(asgi_message)
    except Exception as e:
        error_message = f"Unexpected error, traceback: {e}."
        logger.warning(error_message)

        if status_code == "":
            # If first message hasn't been sent, return 500 status.
            await Response(error_message, status_code=500).send(scope, receive, send)
            return "500"
        else:
            # If first message has been sent, terminate the response stream.
            return status_code

    return status_code
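The two failure modes described in the docstring can be reproduced without Ray. The following is a simplified sketch, not the proxy code itself: `forward`, `failing_stream`, and `demo` are hypothetical names, messages are plain dictionaries rather than pickled object refs, and "sending" just appends to a list.

```python
import asyncio
from typing import Any, AsyncIterator, Dict, List

ASGIMessage = Dict[str, Any]


async def forward(messages: AsyncIterator[ASGIMessage], sent: List[ASGIMessage]) -> str:
    """Forward messages into `sent`, mimicking the before/after-first-message split."""
    status_code = ""
    try:
        async for message in messages:
            if not status_code:
                # The first message must carry the status code.
                status_code = str(message["status"])
            sent.append(message)
    except Exception:
        if status_code == "":
            # Nothing has been sent yet, so a 500 can still be reported.
            sent.append({"type": "http.response.start", "status": 500})
            return "500"
        # The status is already on the wire; the stream simply ends here.
    return status_code


async def failing_stream(fail_after: int) -> AsyncIterator[ASGIMessage]:
    """Yield the first `fail_after` scripted messages, then raise."""
    script: List[ASGIMessage] = [
        {"type": "http.response.start", "status": 200},
        {"type": "http.response.body", "body": b"0", "more_body": True},
    ]
    for message in script[:fail_after]:
        yield message
    raise RuntimeError("simulated replica failure")


async def demo():
    sent_before: List[ASGIMessage] = []
    sent_after: List[ASGIMessage] = []
    status_before = await forward(failing_stream(0), sent_before)
    status_after = await forward(failing_stream(2), sent_after)
    return status_before, sent_before, status_after, sent_after


status_before, sent_before, status_after, sent_after = asyncio.run(demo())
```

In the second case the caller still sees status `200` even though the stream was cut short, which is why a client cannot rely on the status code alone to detect a failure that happens mid-stream.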


async def _send_request_to_handle(handle, scope, receive, send) -> str:
http_body_bytes = await receive_http_body(scope, receive, send)

@@ -112,6 +171,11 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str:
try:
object_ref = await assignment_task

if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator):
return await _handle_streaming_response(
object_ref, scope, receive, send
)

# NOTE (shrekris-anyscale): when the gcs, Serve controller, and
# some replicas crash simultaneously (e.g. if the head node crashes),
# requests to the dead replicas hang until the gcs recovers.
@@ -139,8 +203,8 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str:
# Here because the client disconnected, we will return a custom
# error code for metric tracking.
return DISCONNECT_ERROR_CODE
except RayTaskError as error:
error_message = "Task Error. Traceback: {}.".format(error)
except RayTaskError as e:
error_message = f"Unexpected error, traceback: {e}."
await Response(error_message, status_code=500).send(scope, receive, send)
return "500"
except RayActorError:
@@ -277,6 +341,7 @@ def get_handle(name):
sync=False,
missing_ok=True,
_internal_pickled_http_request=True,
_stream=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
)

self.prefix_router = LongestPrefixRouter(get_handle)
@@ -439,7 +504,6 @@ async def __call__(self, scope, receive, send):
ray.serve.context.RequestContext(**request_context_info)
)
status_code = await _send_request_to_handle(handle, scope, receive, send)

self.request_counter.inc(
tags={
"route": route_path,