From 605f2302c9cbc4a968932715c7d67c3539bdd879 Mon Sep 17 00:00:00 2001 From: shrekris-anyscale <92341594+shrekris-anyscale@users.noreply.github.com> Date: Thu, 6 Jul 2023 09:26:43 -0700 Subject: [PATCH] [Serve] [Docs] Add end-to-end documentation for streaming (#36961) Serve has recently added streaming and WebSocket support. This change adds end-to-end examples to guide users through these features. Link to documentation: https://anyscale-ray--36961.com.readthedocs.build/en/36961/serve/tutorials/streaming.html Co-authored-by: Edward Oakes Co-authored-by: angelinalg <122562471+angelinalg@users.noreply.github.com> --- .buildkite/pipeline.ml.yml | 3 + .../serve/advanced-guides/dyn-req-batch.md | 2 + .../serve/doc_code/streaming_tutorial.py | 348 ++++++++++++++++++ doc/source/serve/tutorials/index.md | 1 + doc/source/serve/tutorials/streaming.md | 211 +++++++++++ 5 files changed, 565 insertions(+) create mode 100644 doc/source/serve/doc_code/streaming_tutorial.py create mode 100644 doc/source/serve/tutorials/streaming.md diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml index e177dd7c0a6c..75404bb3eafb 100644 --- a/.buildkite/pipeline.ml.yml +++ b/.buildkite/pipeline.ml.yml @@ -490,6 +490,9 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT - DOC_TESTING=1 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh + # TODO (shrekris-anyscale): Remove transformers after core transformer + # requirement is upgraded + - pip install "transformers==4.30.2" - ./ci/env/env_info.sh - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-timeseries_libs,-external,-ray_air,-gpu,-post_wheel_build,-doctest,-datasets_train,-highly_parallel doc/... diff --git a/doc/source/serve/advanced-guides/dyn-req-batch.md b/doc/source/serve/advanced-guides/dyn-req-batch.md index 666434eec9d3..98beeb96a2ae 100644 --- a/doc/source/serve/advanced-guides/dyn-req-batch.md +++ b/doc/source/serve/advanced-guides/dyn-req-batch.md @@ -47,6 +47,8 @@ end-before: __batch_params_update_end__ Use these methods in the `reconfigure` [method](serve-user-config) to control the `@serve.batch` parameters through your Serve configuration file. 
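For reference, a minimal, hypothetical sketch of this pattern (not part of this change) could look like the following, using the `set_max_batch_size` and `set_batch_wait_timeout_s` setters shown above and a `user_config` with matching keys:

```python
# Hypothetical sketch: updating @serve.batch parameters from the user_config.
from typing import Dict, List

from ray import serve


@serve.deployment(user_config={"max_batch_size": 4, "batch_wait_timeout_s": 0.1})
class Adder:
    @serve.batch(max_batch_size=2, batch_wait_timeout_s=0.0)
    async def add_one(self, numbers: List[int]) -> List[int]:
        # Runs once per batch and returns one result per input.
        return [n + 1 for n in numbers]

    def reconfigure(self, user_config: Dict):
        # Runs at deployment time and whenever user_config changes.
        self.add_one.set_max_batch_size(user_config["max_batch_size"])
        self.add_one.set_batch_wait_timeout_s(user_config["batch_wait_timeout_s"])

    async def __call__(self, number: int) -> int:
        # Intended to be called through a DeploymentHandle, for example
        # handle.remote(5).
        return await self.add_one(number)


app = Adder.bind()
```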
::: +(serve-streaming-batched-requests-guide)= + ## Streaming batched requests ```{warning} diff --git a/doc/source/serve/doc_code/streaming_tutorial.py b/doc/source/serve/doc_code/streaming_tutorial.py new file mode 100644 index 000000000000..bd23858e95f1 --- /dev/null +++ b/doc/source/serve/doc_code/streaming_tutorial.py @@ -0,0 +1,348 @@ +# flake8: noqa +# fmt: off + +from typing import List + +# __textbot_setup_start__ +import asyncio +import logging +from queue import Empty + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from ray import serve + +logger = logging.getLogger("ray.serve") +# __textbot_setup_end__ + + +# __textbot_constructor_start__ +fastapi_app = FastAPI() + + +@serve.deployment +@serve.ingress(fastapi_app) +class Textbot: + def __init__(self, model_id: str): + self.loop = asyncio.get_running_loop() + + self.model_id = model_id + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + # __textbot_constructor_end__ + + # __textbot_logic_start__ + @fastapi_app.post("/") + def handle_request(self, prompt: str) -> StreamingResponse: + logger.info(f'Got prompt: "{prompt}"') + streamer = TextIteratorStreamer( + self.tokenizer, timeout=0, skip_prompt=True, skip_special_tokens=True + ) + self.loop.run_in_executor(None, self.generate_text, prompt, streamer) + return StreamingResponse( + self.consume_streamer(streamer), media_type="text/plain" + ) + + def generate_text(self, prompt: str, streamer: TextIteratorStreamer): + input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids + self.model.generate(input_ids, streamer=streamer, max_length=10000) + + async def consume_streamer(self, streamer: TextIteratorStreamer): + while True: + try: + for token in streamer: + logger.info(f'Yielding token: "{token}"') + yield token + break + except Empty: + # The streamer raises an Empty exception if the next token + # hasn't been generated yet. `await` here to yield control + # back to the event loop so other coroutines can run. + await asyncio.sleep(0.001) + + # __textbot_logic_end__ + + +# __textbot_bind_start__ +app = Textbot.bind("microsoft/DialoGPT-small") +# __textbot_bind_end__ + + +serve.run(app) + +chunks = [] +# __stream_client_start__ +import requests + +prompt = "Tell me a story about dogs." + +response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True) +response.raise_for_status() +for chunk in response.iter_content(chunk_size=None, decode_unicode=True): + print(chunk, end="") + + # Dogs are the best. + # __stream_client_end__ + chunks.append(chunk) + +# Check that streaming is happening. 
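# NOTE: The exact chunk boundaries below assume DialoGPT-small's greedy
# (deterministic) output for this prompt; a different model, tokenizer, or
# prompt produces different chunks.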
+assert chunks == ["Dogs ", "are ", "the ", "best."] + + +# __chatbot_setup_start__ +import asyncio +import logging +from queue import Empty + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from ray import serve + +logger = logging.getLogger("ray.serve") +# __chatbot_setup_end__ + + +# __chatbot_constructor_start__ +fastapi_app = FastAPI() + + +@serve.deployment +@serve.ingress(fastapi_app) +class Chatbot: + def __init__(self, model_id: str): + self.loop = asyncio.get_running_loop() + + self.model_id = model_id + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + # __chatbot_constructor_end__ + + # __chatbot_logic_start__ + @fastapi_app.websocket("/") + async def handle_request(self, ws: WebSocket) -> None: + await ws.accept() + + conversation = "" + try: + while True: + prompt = await ws.receive_text() + logger.info(f'Got prompt: "{prompt}"') + conversation += prompt + streamer = TextIteratorStreamer( + self.tokenizer, + timeout=0, + skip_prompt=True, + skip_special_tokens=True, + ) + self.loop.run_in_executor( + None, self.generate_text, conversation, streamer + ) + response = "" + async for text in self.consume_streamer(streamer): + await ws.send_text(text) + response += text + await ws.send_text("<>") + conversation += response + except WebSocketDisconnect: + print("Client disconnected.") + + def generate_text(self, prompt: str, streamer: TextIteratorStreamer): + input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids + self.model.generate(input_ids, streamer=streamer, max_length=10000) + + async def consume_streamer(self, streamer: TextIteratorStreamer): + while True: + try: + for token in streamer: + logger.info(f'Yielding token: "{token}"') + yield token + break + except Empty: + await asyncio.sleep(0.001) + + +# __chatbot_logic_end__ + + +# __chatbot_bind_start__ +app = Chatbot.bind("microsoft/DialoGPT-small") +# __chatbot_bind_end__ + +serve.run(app) + +chunks = [] +# Monkeypatch `print` for testing +original_print, print = print, (lambda chunk, end=None: chunks.append(chunk)) + +# __ws_client_start__ +from websockets.sync.client import connect + +with connect("ws://localhost:8000") as websocket: + websocket.send("Space the final") + while True: + received = websocket.recv() + if received == "<>": + break + print(received, end="") + print("\n") + + websocket.send(" These are the voyages") + while True: + received = websocket.recv() + if received == "<>": + break + print(received, end="") + print("\n") +# __ws_client_end__ + +assert chunks == [ + " ", + "", + "", + "frontier.", + "\n", + " ", + "of ", + "the ", + "starship ", + "", + "", + "Enterprise.", + "\n", +] + +print = original_print + +# __batchbot_setup_start__ +import asyncio +import logging +from queue import Empty, Queue + +from fastapi import FastAPI +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ray import serve + +logger = logging.getLogger("ray.serve") +# __batchbot_setup_end__ + +# __raw_streamer_start__ +class RawStreamer: + def __init__(self, timeout: float = None): + self.q = Queue() + self.stop_signal = None + self.timeout = timeout + + def put(self, values): + self.q.put(values) + + def end(self): + self.q.put(self.stop_signal) + + def __iter__(self): + return self + + def __next__(self): + result = self.q.get(timeout=self.timeout) + if result == self.stop_signal: + raise StopIteration() + else: 
+ return result + + +# __raw_streamer_end__ + +# __batchbot_constructor_start__ +fastapi_app = FastAPI() + + +@serve.deployment +@serve.ingress(fastapi_app) +class Batchbot: + def __init__(self, model_id: str): + self.loop = asyncio.get_running_loop() + + self.model_id = model_id + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # __batchbot_constructor_end__ + + # __batchbot_logic_start__ + @fastapi_app.post("/") + async def handle_request(self, prompt: str) -> StreamingResponse: + logger.info(f'Got prompt: "{prompt}"') + return StreamingResponse(self.run_model(prompt), media_type="text/plain") + + @serve.batch(max_batch_size=2, batch_wait_timeout_s=15) + async def run_model(self, prompts: List[str]): + streamer = RawStreamer() + self.loop.run_in_executor(None, self.generate_text, prompts, streamer) + on_prompt_tokens = True + async for decoded_token_batch in self.consume_streamer(streamer): + # The first batch of tokens contains the prompts, so we skip it. + if not on_prompt_tokens: + logger.info(f"Yielding decoded_token_batch: {decoded_token_batch}") + yield decoded_token_batch + else: + logger.info(f"Skipped prompts: {decoded_token_batch}") + on_prompt_tokens = False + + def generate_text(self, prompts: str, streamer: RawStreamer): + input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True).input_ids + self.model.generate(input_ids, streamer=streamer, max_length=10000) + + async def consume_streamer(self, streamer: RawStreamer): + while True: + try: + for token_batch in streamer: + decoded_tokens = [] + for token in token_batch: + decoded_tokens.append( + self.tokenizer.decode(token, skip_special_tokens=True) + ) + logger.info(f"Yielding decoded tokens: {decoded_tokens}") + yield decoded_tokens + break + except Empty: + await asyncio.sleep(0.001) + + +# __batchbot_logic_end__ + + +# __batchbot_bind_start__ +app = Batchbot.bind("microsoft/DialoGPT-small") +# __batchbot_bind_end__ + +serve.run(app) + +# Test batching code +from functools import partial +from concurrent.futures.thread import ThreadPoolExecutor + + +def get_buffered_response(prompt) -> List[str]: + response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True) + chunks = [] + for chunk in response.iter_content(chunk_size=None, decode_unicode=True): + chunks.append(chunk) + return chunks + + +with ThreadPoolExecutor() as pool: + futs = [ + pool.submit(partial(get_buffered_response, prompt)) + for prompt in ["Introduce yourself to me!", "Tell me a story about dogs."] + ] + responses = [fut.result() for fut in futs] + assert responses == [ + ["I", "'m", " not", " sure", " if", " I", "'m", " ready", " for", " that", "."], + ["D", "ogs", " are", " the", " best", "."], + ] diff --git a/doc/source/serve/tutorials/index.md b/doc/source/serve/tutorials/index.md index 22737b9423a2..cd350af78f87 100644 --- a/doc/source/serve/tutorials/index.md +++ b/doc/source/serve/tutorials/index.md @@ -16,6 +16,7 @@ object-detection rllib gradio-integration batch +streaming gradio-dag-visualization java ``` diff --git a/doc/source/serve/tutorials/streaming.md b/doc/source/serve/tutorials/streaming.md new file mode 100644 index 000000000000..6c2a4d061562 --- /dev/null +++ b/doc/source/serve/tutorials/streaming.md @@ -0,0 +1,211 @@ +(serve-streaming-tutorial)= + +# Streaming Tutorial + +:::{warning} +Support for streaming is experimental. 
To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1` on the cluster before starting Ray. If you encounter any issues, [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose).
:::

This guide walks you through deploying a chatbot that streams output back to the
user. It shows:

* How to stream outputs from a Serve application
* How to use WebSockets in a Serve application
* How to combine batching requests with streaming outputs

This tutorial should help you with the following use cases:

* You want to serve a large language model and stream results back token-by-token.
* You want to serve a chatbot that accepts a stream of inputs from the user.

This tutorial serves the [DialoGPT](https://huggingface.co/microsoft/DialoGPT-small) language model. Install the HuggingFace library to access it:

```
pip install transformers
```

## Create a streaming deployment

Open a new Python file called `textbot.py`. First, add the imports and the [Serve logger](serve-logging).

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __textbot_setup_start__
:end-before: __textbot_setup_end__
```

Create a [FastAPI deployment](serve-fastapi-http), and initialize the model and the tokenizer in the
constructor:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __textbot_constructor_start__
:end-before: __textbot_constructor_end__
```

Note that the constructor also caches an `asyncio` loop. Caching the loop is useful when you need to run a model and concurrently stream its tokens back to the user.

Add the following logic to handle requests sent to the `Textbot`:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __textbot_logic_start__
:end-before: __textbot_logic_end__
```

`Textbot` uses three methods to handle requests:

* `handle_request`: the entrypoint for HTTP requests. FastAPI automatically unpacks the `prompt` query parameter and passes it into `handle_request`. This method then creates a `TextIteratorStreamer`. HuggingFace provides this streamer as a convenient interface to access tokens generated by a language model. `handle_request` then kicks off the model in a background thread using `self.loop.run_in_executor`. This behavior lets the model generate tokens while `handle_request` concurrently calls `self.consume_streamer` to stream the tokens back to the user. `self.consume_streamer` is a generator that yields tokens one by one from the streamer. Lastly, `handle_request` passes the `self.consume_streamer` generator into a Starlette `StreamingResponse` and returns the response. Serve unpacks the Starlette `StreamingResponse` and yields the contents of the generator back to the user one by one.
* `generate_text`: the method that runs the model. This method runs in a background thread kicked off by `handle_request`. It pushes generated tokens into the streamer constructed by `handle_request`.
* `consume_streamer`: a generator method that consumes the streamer constructed by `handle_request`. This method keeps yielding tokens from the streamer until the model in `generate_text` closes the streamer. This method avoids blocking the event loop by calling `asyncio.sleep` with a brief timeout whenever the streamer is empty and waiting for a new token.
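To see this producer-consumer pattern in isolation, here's a minimal sketch that doesn't depend on Serve or transformers. The `fake_generate` and `consume` helpers are illustrative stand-ins for `generate_text` and `consume_streamer`, and a plain `queue.Queue` plays the role of the `TextIteratorStreamer`:

```python
# A minimal sketch of the producer-consumer streaming pattern, with a fake
# word generator standing in for the language model.
import asyncio
import time
from queue import Empty, Queue


def fake_generate(prompt: str, q: Queue) -> None:
    # Blocking producer, analogous to `generate_text`: pushes tokens into the
    # queue as they become available, then a sentinel to mark the end.
    for word in f"Echoing: {prompt}".split():
        time.sleep(0.2)  # Simulate per-token generation latency.
        q.put(word + " ")
    q.put(None)


async def consume(q: Queue):
    # Async generator, analogous to `consume_streamer`: drains the queue
    # without blocking the event loop.
    while True:
        try:
            token = q.get_nowait()
        except Empty:
            await asyncio.sleep(0.001)  # Yield control back to the event loop.
            continue
        if token is None:
            break
        yield token


async def main():
    q = Queue()
    loop = asyncio.get_running_loop()
    # Run the blocking producer in a background thread, like
    # `self.loop.run_in_executor(None, self.generate_text, prompt, streamer)`.
    loop.run_in_executor(None, fake_generate, "streaming works", q)
    async for token in consume(q):
        print(token, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())
```

`Textbot` follows the same shape: a blocking producer in a background thread, an async consumer that yields control while it waits, and a Starlette `StreamingResponse` that forwards whatever the consumer yields.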
Bind the `Textbot` to a language model. For this tutorial, use the `"microsoft/DialoGPT-small"` model:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __textbot_bind_start__
:end-before: __textbot_bind_end__
```

Run the model with `serve run textbot:app`, and query it from another terminal window with this script:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __stream_client_start__
:end-before: __stream_client_end__
```

You should see the output printed token by token.

## Stream inputs and outputs using WebSockets

WebSockets let you stream input into the application and stream output back to the client. Use WebSockets to create a chatbot that stores a conversation with a user.

Create a Python file called `chatbot.py`. First, add the imports:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __chatbot_setup_start__
:end-before: __chatbot_setup_end__
```

Create a FastAPI deployment, and initialize the model and the tokenizer in the
constructor:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __chatbot_constructor_start__
:end-before: __chatbot_constructor_end__
```

Add the following logic to handle requests sent to the `Chatbot`:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __chatbot_logic_start__
:end-before: __chatbot_logic_end__
```

The `generate_text` and `consume_streamer` methods are the same as they were for the `Textbot`. The `handle_request` method has been updated to handle WebSocket requests.

The `handle_request` method is decorated with a `fastapi_app.websocket` decorator, which lets it accept WebSocket requests. First, it accepts the client's WebSocket connection with `await ws.accept()`. Then, until the client disconnects, it does the following:

* gets the prompt from the client with `ws.receive_text`
* starts a new `TextIteratorStreamer` to access generated tokens
* runs the model in a background thread on the conversation so far
* streams the model's output back using `ws.send_text`
* stores the prompt and the response in the `conversation` string

Each time `handle_request` gets a new prompt from a client, it runs the whole conversation, with the new prompt appended, through the model. When the model is finished generating tokens, `handle_request` sends the `"<>"` string to inform the client that all tokens have been generated. `handle_request` continues to run until the client explicitly disconnects. This disconnect raises a `WebSocketDisconnect` exception, which ends the call.

Read more about WebSockets in the [FastAPI documentation](https://fastapi.tiangolo.com/advanced/websockets/).

Bind the `Chatbot` to a language model. For this tutorial, use the `"microsoft/DialoGPT-small"` model:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __chatbot_bind_start__
:end-before: __chatbot_bind_end__
```

Run the model with `serve run chatbot:app`. Query it using the `websockets` package (`pip install websockets`):

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __ws_client_start__
:end-before: __ws_client_end__
```

You should see the outputs printed token by token.
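If you plan to reuse the client logic, you can wrap the send-and-receive loop in a small helper. The `chat` function below is an illustrative sketch, not part of the tutorial code, that sends one prompt and collects streamed tokens until the `"<>"` marker that `Chatbot` sends when a response is complete:

```python
# A small client helper that hides the streaming protocol details.
from websockets.sync.client import connect


def chat(websocket, prompt: str) -> str:
    """Send a prompt and return the fully assembled streamed response."""
    websocket.send(prompt)
    tokens = []
    while True:
        received = websocket.recv()
        if received == "<>":  # Chatbot's end-of-response marker.
            break
        tokens.append(received)
    return "".join(tokens)


if __name__ == "__main__":
    with connect("ws://localhost:8000") as websocket:
        print(chat(websocket, "Space the final"))
        print(chat(websocket, " These are the voyages"))
```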
## Batch requests and stream the output for each

Improve model utilization and request latency by batching requests together when running the model.

Create a Python file called `batchbot.py`. First, add the imports:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __batchbot_setup_start__
:end-before: __batchbot_setup_end__
```

Because the `Batchbot`'s `handle_request` method returns a Starlette `StreamingResponse`, `batchbot.py` also needs `from starlette.responses import StreamingResponse`, just like `textbot.py`.

:::{warning}
HuggingFace's support for `Streamers` is still under development and may change in the future. The custom `RawStreamer` used in this tutorial is compatible with the `Streamers` interface in HuggingFace 4.30.2. However, the `Streamers` interface may change, making `RawStreamer` incompatible with HuggingFace models in the future.
:::

Just like the `Textbot` and `Chatbot`, the `Batchbot` needs a streamer to stream outputs from batched requests, but HuggingFace `Streamers` don't support batched requests yet. Add this custom `RawStreamer` to process batches of tokens:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __raw_streamer_start__
:end-before: __raw_streamer_end__
```

Create a FastAPI deployment, and initialize the model and the tokenizer in the
constructor:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __batchbot_constructor_start__
:end-before: __batchbot_constructor_end__
```

Unlike `Textbot` and `Chatbot`, the `Batchbot` constructor also sets a `pad_token`. Setting a pad token is required to batch prompts of different lengths together.

Add the following logic to handle requests sent to the `Batchbot`:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __batchbot_logic_start__
:end-before: __batchbot_logic_end__
```

`Batchbot` uses four methods to handle requests:

* `handle_request`: the entrypoint method. This method simply takes in the request's prompt and calls the `run_model` method on it. `run_model` is a generator method that also handles batching the requests. `handle_request` passes `run_model` into a Starlette `StreamingResponse` and returns the response, so generated tokens can be streamed back to the client.
* `run_model`: a generator method that performs batching. Since `run_model` is decorated with `@serve.batch`, it automatically takes in a batch of prompts. See the [batching guide](serve-batch-tutorial) for more info. `run_model` creates a `RawStreamer` to access the generated tokens. It calls `generate_text` in a background thread, and passes in the `prompts` and the `streamer`, similar to the `Textbot`. Then it iterates through the `consume_streamer` generator, repeatedly yielding a batch of tokens generated by the model.
* `generate_text`: the method that runs the model. It's mostly the same as `generate_text` in `Textbot`, with two differences. First, it takes in and processes a batch of prompts instead of a single prompt. Second, it sets `padding=True`, so prompts with different lengths can be batched together.
* `consume_streamer`: a generator method that consumes the streamer constructed by `run_model`. It's mostly the same as `consume_streamer` in `Textbot`, with one difference. It uses the `tokenizer` to decode the generated tokens. Usually, the HuggingFace streamer handles decoding. Since this implementation uses the custom `RawStreamer`, `consume_streamer` must handle the decoding.

:::{tip}
Some inputs within a batch may generate fewer outputs than others. When a particular input has nothing left to yield, pass a `StopIteration` object into the output iterable to terminate that input's request. See [this section](serve-streaming-batched-requests-guide) for more info.
:::
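As a rough illustration of that tip, here's a minimal sketch adapted from the linked guide rather than from this tutorial's code. Each request streams a different number of values, and requests that finish early are terminated by yielding `StopIteration` in their positions in the batch:

```python
# Minimal sketch of uneven batched streaming: a request that has nothing left
# to yield gets StopIteration in its slot, which closes its stream.
from typing import AsyncGenerator, List

from starlette.requests import Request
from starlette.responses import StreamingResponse

from ray import serve


@serve.deployment
class StreamingResponder:
    @serve.batch(max_batch_size=5, batch_wait_timeout_s=0.1)
    async def generate_numbers(self, max_list: List[str]) -> AsyncGenerator[List, None]:
        for i in range(max(int(m) for m in max_list)):
            # One entry per request in the batch: a value if the request is
            # still streaming, or StopIteration if it's finished.
            yield [f"{i} " if i < int(m) else StopIteration for m in max_list]

    async def __call__(self, request: Request) -> StreamingResponse:
        gen = self.generate_numbers(request.query_params["max"])
        return StreamingResponse(gen, media_type="text/plain")


app = StreamingResponder.bind()
```

After `serve run` starts this sketch, concurrent requests like `http://localhost:8000/?max=5` and `http://localhost:8000/?max=2` stream different numbers of values while sharing a batch.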
Bind the `Batchbot` to a language model. For this tutorial, use the `"microsoft/DialoGPT-small"` model:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __batchbot_bind_start__
:end-before: __batchbot_bind_end__
```

Run the model with `serve run batchbot:app`. Query it from two other terminal windows with this script:

```{literalinclude} ../doc_code/streaming_tutorial.py
:language: python
:start-after: __stream_client_start__
:end-before: __stream_client_end__
```

You should see the output printed token by token in both windows.
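Alternatively, if you'd rather not juggle multiple terminal windows, this optional sketch, modeled on the tutorial's test code, sends two prompts concurrently from a single script so `Batchbot` can batch them together. It assumes the app is still running at `localhost:8000`:

```python
# Send two prompts concurrently so that @serve.batch can group them.
from concurrent.futures import ThreadPoolExecutor

import requests


def stream_response(prompt: str) -> str:
    # Collect the streamed chunks for one prompt into a single string.
    response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True)
    response.raise_for_status()
    chunks = []
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        chunks.append(chunk)
    return "".join(chunks)


prompts = ["Introduce yourself to me!", "Tell me a story about dogs."]
with ThreadPoolExecutor(max_workers=2) as pool:
    replies = list(pool.map(stream_response, prompts))

for prompt, reply in zip(prompts, replies):
    print(f"{prompt!r} -> {reply!r}")
```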