[Serve] Make max_batch_size and batch_wait_timeout_s reconfigurable (#36881)

The `@serve.batch` decorator has two parameters: `max_batch_size` and `batch_wait_timeout_s`. These parameters can be set in the decorator. However, they cannot be reconfigured after the Serve application starts.

This change adds two setter methods: `set_max_batch_size` and `set_batch_wait_timeout_s`. Users can reconfigure their `@serve.batch` parameters using these methods:

```python
@serve.deployment
class Model:
    @serve.batch(max_batch_size=1, batch_wait_timeout_s=0.1)
    async def batch_handler(self, request_list):
        ...

    def update_batch_params(self):
        self.batch_handler.set_max_batch_size(5)
        self.batch_handler.set_batch_wait_timeout_s(0.5)
```
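Note that the batch queue snapshots `max_batch_size` and `batch_wait_timeout_s` when it begins assembling a batch, so values changed through the setters take effect with the next batch rather than the one currently being formed.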

Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Co-authored-by: angelinalg <122562471+angelinalg@users.noreply.github.com>
3 people authored Jun 28, 2023
1 parent 960032a commit 39195ff
Showing 5 changed files with 298 additions and 36 deletions.
13 changes: 13 additions & 0 deletions doc/source/serve/advanced-guides/dyn-req-batch.md
@@ -34,6 +34,19 @@ You can supply two optional parameters to the decorators.
- `max_batch_size` controls the size of the batch.
Once the first request arrives, the batching decorator will wait for a full batch (up to `max_batch_size`) until `batch_wait_timeout_s` is reached. If the timeout is reached, the batch will be sent to the model regardless of the batch size.

:::{tip}
You can reconfigure your `batch_wait_timeout_s` and `max_batch_size` parameters using the `set_batch_wait_timeout_s` and `set_max_batch_size` methods:

```{literalinclude} ../doc_code/batching_guide.py
---
start-after: __batch_params_update_begin__
end-before: __batch_params_update_end__
---
```

Use these methods in the `reconfigure` [method](serve-in-production-reconfigure) to control the `@serve.batch` parameters through your Serve configuration file.
:::
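For illustration, here is a rough sketch (not part of this commit's doc code) of a `reconfigure` method that forwards `user_config` values to the setters; the `max_batch_size` and `batch_wait_timeout_s` keys are assumed names for the config entries:

```python
from typing import Dict, List

import numpy as np
from ray import serve


@serve.deployment
class ConfigurableModel:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def batch_handler(self, samples: List[int]) -> List[int]:
        # Process the whole batch at once with numpy.
        return (np.array(samples) * 2).tolist()

    def reconfigure(self, user_config: Dict):
        # Serve calls this whenever the user_config field in the config file changes.
        self.batch_handler.set_max_batch_size(user_config["max_batch_size"])
        self.batch_handler.set_batch_wait_timeout_s(
            user_config["batch_wait_timeout_s"]
        )

    async def __call__(self, sample: int) -> int:
        return await self.batch_handler(sample)
```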

## Streaming batched requests

```{warning}
18 changes: 18 additions & 0 deletions doc/source/serve/doc_code/batching_guide.py
@@ -36,6 +36,24 @@ async def __call__(self, multiple_samples: List[int]) -> List[int]:
# __batch_end__


# __batch_params_update_begin__
@serve.deployment
class Model:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def __call__(self, multiple_samples: List[int]) -> List[int]:
        # Use numpy's vectorized computation to efficiently process a batch.
        return np.array(multiple_samples) * 2

    def adjust_batch_parameters(
        self, new_max_batch_size: int, new_batch_wait_timeout_s: float
    ):
        self.__call__.set_max_batch_size(new_max_batch_size)
        self.__call__.set_batch_wait_timeout_s(new_batch_wait_timeout_s)


# __batch_params_update_end__


# __single_stream_begin__
import asyncio
from typing import AsyncGenerator
2 changes: 2 additions & 0 deletions doc/source/serve/production-guide/config.md
@@ -171,6 +171,8 @@ Serve will set `num_replicas=5`, using the config file value, and `max_concurren
Remember that `ray_actor_options` counts as a single setting. The entire `ray_actor_options` dictionary in the config file overrides the entire `ray_actor_options` dictionary from the graph code. If there are individual options within `ray_actor_options` (e.g. `runtime_env`, `num_gpus`, `memory`) that are set in the code but not in the config, Serve still won't use the code settings if the config has a `ray_actor_options` dictionary. It will treat these missing options as though the user never set them and will use defaults instead. This dictionary overriding behavior also applies to `user_config`.
:::
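To make the overriding rule concrete, here is a hedged sketch (illustrative names, not from this commit): when the config file supplies any `ray_actor_options` dictionary, it replaces the code's dictionary as a whole.

```python
from ray import serve


# In code: two options set inside ray_actor_options.
@serve.deployment(ray_actor_options={"num_gpus": 1, "memory": 1000})
class Example:
    async def __call__(self) -> str:
        return "ok"


# If the config file supplies, for example:
#   ray_actor_options:
#     num_cpus: 2
# then the entire code-side dictionary is ignored: num_gpus and memory fall back
# to their defaults because the config's ray_actor_options wins as a unit.
```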

(serve-in-production-reconfigure)=

## Dynamically adjusting parameters in deployment

The `user_config` field can be used to supply structured configuration for your deployment. You can pass arbitrary JSON serializable objects to the YAML configuration. Serve will then apply it to all running and future deployment replicas. The application of user configuration *will not* restart the replica. This means you can use this field to dynamically:
123 changes: 103 additions & 20 deletions python/ray/serve/batching.py
@@ -66,7 +66,7 @@ class _BatchQueue:
def __init__(
self,
max_batch_size: int,
timeout_s: float,
batch_wait_timeout_s: float,
handle_batch_func: Optional[Callable] = None,
) -> None:
"""Async queue that accepts individual items and returns batches.
@@ -78,6 +78,8 @@ def __init__(
If handle_batch_func is passed in, a background coroutine will run to
poll from the queue and call handle_batch_func on the results.
Cannot be pickled.
Arguments:
max_batch_size: max number of elements to return in a batch.
timeout_s: time to wait before returning an incomplete
@@ -87,7 +89,7 @@ def __init__(
"""
self.queue: asyncio.Queue[_SingleRequest] = asyncio.Queue()
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
self.batch_wait_timeout_s = batch_wait_timeout_s
self.queue_put_event = asyncio.Event()

self._handle_batch_task = None
@@ -114,11 +116,15 @@ async def wait_for_batch(self) -> List[Any]:
batch = []
batch.append(await self.queue.get())

# Cache current max_batch_size and batch_wait_timeout_s for this batch.
max_batch_size = self.max_batch_size
batch_wait_timeout_s = self.batch_wait_timeout_s

# Wait self.timeout_s seconds for new queue arrivals.
batch_start_time = time.time()
while True:
remaining_batch_time_s = max(
self.timeout_s - (time.time() - batch_start_time), 0
batch_wait_timeout_s - (time.time() - batch_start_time), 0
)
try:
# Wait for new arrivals.
@@ -129,13 +135,13 @@ async def wait_for_batch(self) -> List[Any]:
pass

# Add all new arrivals to the batch.
while len(batch) < self.max_batch_size and not self.queue.empty():
while len(batch) < max_batch_size and not self.queue.empty():
batch.append(self.queue.get_nowait())
self.queue_put_event.clear()

if (
time.time() - batch_start_time >= self.timeout_s
or len(batch) >= self.max_batch_size
time.time() - batch_start_time >= batch_wait_timeout_s
or len(batch) >= max_batch_size
):
break

@@ -237,6 +243,62 @@ def __del__(self):
self._handle_batch_task.cancel()


class _LazyBatchQueueWrapper:
    """Stores a _BatchQueue and updates its settings.
    _BatchQueue cannot be pickled, you must construct it lazily
    at runtime inside a replica. This class initializes a queue only upon
    first access.
    """

    def __init__(
        self,
        max_batch_size: int = 10,
        batch_wait_timeout_s: float = 0.0,
        handle_batch_func: Optional[Callable] = None,
        batch_queue_cls: Type[_BatchQueue] = _BatchQueue,
    ):
        self._queue: Type[_BatchQueue] = None
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.handle_batch_func = handle_batch_func
        self.batch_queue_cls = batch_queue_cls

    @property
    def queue(self) -> Type[_BatchQueue]:
        """Returns _BatchQueue.
        Initializes queue when called for the first time.
        """
        if self._queue is None:
            self._queue = self.batch_queue_cls(
                self.max_batch_size,
                self.batch_wait_timeout_s,
                self.handle_batch_func,
            )
        return self._queue

    def set_max_batch_size(self, new_max_batch_size: int) -> None:
        """Updates queue's max_batch_size."""

        self.max_batch_size = new_max_batch_size

        if self._queue is not None:
            self._queue.max_batch_size = new_max_batch_size

    def set_batch_wait_timeout_s(self, new_batch_wait_timeout_s: float) -> None:
        self.batch_wait_timeout_s = new_batch_wait_timeout_s

        if self._queue is not None:
            self._queue.batch_wait_timeout_s = new_batch_wait_timeout_s

    def get_max_batch_size(self) -> int:
        return self.max_batch_size

    def get_batch_wait_timeout_s(self) -> float:
        return self.batch_wait_timeout_s
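As an aside (not part of the diff), a small sketch of the lazy behavior described above: setter calls made before the queue exists only update the stored values, and the `_BatchQueue` is built on first access of `.queue`:

```python
# Illustrative only; uses the private wrapper defined above.
from ray.serve.batching import _LazyBatchQueueWrapper

wrapper = _LazyBatchQueueWrapper(max_batch_size=4, batch_wait_timeout_s=0.05)
wrapper.set_max_batch_size(8)           # no _BatchQueue yet; only the stored value changes
queue = wrapper.queue                   # _BatchQueue(8, 0.05, None) is constructed here
wrapper.set_batch_wait_timeout_s(0.2)   # now also forwarded to the live queue
```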


def _validate_max_batch_size(max_batch_size):
    if not isinstance(max_batch_size, int):
        if isinstance(max_batch_size, float) and max_batch_size.is_integer():
@@ -303,6 +365,10 @@ def batch(
and executed asynchronously once there is a batch of `max_batch_size`
or `batch_wait_timeout_s` has elapsed, whichever occurs first.
`max_batch_size` and `batch_wait_timeout_s` can be updated using setter
methods from the batch_handler (`set_max_batch_size` and
`set_batch_wait_timeout_s`).
Example:
.. code-block:: python
@@ -321,6 +387,10 @@ async def batch_handler(self, requests: List[Request]) -> List[str]:
return response_batch
def update_batch_params(self, max_batch_size, batch_wait_timeout_s):
self.batch_handler.set_max_batch_size(max_batch_size)
self.batch_handler.set_batch_wait_timeout_s(batch_wait_timeout_s)
async def __call__(self, request: Request):
return await self.batch_handler(request)
@@ -348,6 +418,13 @@ async def __call__(self, request: Request):
_validate_batch_wait_timeout_s(batch_wait_timeout_s)

def _batch_decorator(_func):
lazy_batch_queue_wrapper = _LazyBatchQueueWrapper(
max_batch_size,
batch_wait_timeout_s,
_func,
batch_queue_cls,
)

async def batch_handler_generator(
first_future: asyncio.Future,
) -> AsyncGenerator:
@@ -377,17 +454,7 @@ def enqueue_request(args, kwargs) -> asyncio.Future:
# Trim the self argument from methods
flattened_args = flattened_args[2:]

# The first time the function runs, we lazily construct the batch
# queue and inject it under a custom attribute name. On subsequent
# runs, we just get a reference to the attribute.
batch_queue_attr = f"__serve_batch_queue_{_func.__name__}"
if not hasattr(batch_queue_object, batch_queue_attr):
batch_queue = batch_queue_cls(
max_batch_size, batch_wait_timeout_s, _func
)
setattr(batch_queue_object, batch_queue_attr, batch_queue)
else:
batch_queue = getattr(batch_queue_object, batch_queue_attr)
batch_queue = lazy_batch_queue_wrapper.queue

# Magic batch_queue_object attributes that can be used to change the
# batch queue attributes on the fly.
@@ -405,12 +472,15 @@ def enqueue_request(args, kwargs) -> asyncio.Future:
batch_queue_object, "_ray_serve_batch_wait_timeout_s"
)
_validate_batch_wait_timeout_s(new_batch_wait_timeout_s)
batch_queue.timeout_s = new_batch_wait_timeout_s
batch_queue.batch_wait_timeout_s = new_batch_wait_timeout_s

future = get_or_create_event_loop().create_future()
batch_queue.put(_SingleRequest(self, flattened_args, future))
return future

# TODO (shrekris-anyscale): deprecate batch_queue_cls argument and
# convert batch_wrapper into a class once `self` argument is no
# longer needed in `enqueue_request`.
@wraps(_func)
def generator_batch_wrapper(*args, **kwargs):
first_future = enqueue_request(args, kwargs)
@@ -422,9 +492,22 @@ async def batch_wrapper(*args, **kwargs):
return await enqueue_request(args, kwargs)

if isasyncgenfunction(_func):
return generator_batch_wrapper
wrapper = generator_batch_wrapper
else:
return batch_wrapper
wrapper = batch_wrapper

# We store the lazy_batch_queue_wrapper's getters and setters as
# batch_wrapper attributes, so they can be accessed in user code.
wrapper._get_max_batch_size = lazy_batch_queue_wrapper.get_max_batch_size
wrapper._get_batch_wait_timeout_s = (
lazy_batch_queue_wrapper.get_batch_wait_timeout_s
)
wrapper.set_max_batch_size = lazy_batch_queue_wrapper.set_max_batch_size
wrapper.set_batch_wait_timeout_s = (
lazy_batch_queue_wrapper.set_batch_wait_timeout_s
)

return wrapper

# Unfortunately, this is required to handle both non-parametrized
# (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage.
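For reference, a hedged sketch (not from the diff) of the two call forms this comment refers to:

```python
from typing import List

from ray import serve


@serve.deployment
class Example:
    # Non-parametrized usage: the decorator's defaults apply.
    @serve.batch
    async def handle_defaults(self, items: List[int]) -> List[int]:
        return items

    # Parametrized usage: keyword arguments configure the batch.
    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.05)
    async def handle_tuned(self, items: List[int]) -> List[int]:
        return items
```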
