[Serve] Multiplex API Impl (ray-project#35326)
Adds the @serve.multiplexed decorator and serve.get_multiplexed_model_id implementation.
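
For orientation, a minimal usage sketch of the two new APIs (the deployment class, model ID, and loader below are hypothetical, not part of this commit):

from ray import serve

@serve.deployment
class ModelServer:
    @serve.multiplexed(max_num_models_per_replica=3)
    async def get_model(self, model_id: str):
        # Hypothetical loader: construct/fetch the model for this ID.
        return await my_load_model_fn(model_id)

    async def __call__(self, request):
        # The model ID is resolved from the per-request context populated by Serve.
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)
        return model(await request.json())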

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
sihanwang41 authored and arvind-chandra committed Aug 31, 2023
1 parent dd96a40 commit 6a5dfd9
Showing 6 changed files with 356 additions and 37 deletions.
36 changes: 35 additions & 1 deletion python/ray/serve/_private/utils.py
@@ -9,7 +9,17 @@
import traceback
from enum import Enum
from functools import wraps
from typing import Dict, Iterable, List, Tuple, TypeVar, Union
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import fastapi.encoders
import numpy as np
@@ -535,3 +545,27 @@ def record_serve_tag(key: str, value: str):
        )

    record_extra_usage_tag(serve_telemetry_tag_map[key], value)


def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
    """Check if this is a method rather than a function.

    Does this by checking whether `func` is stored as an attribute of the
    first (`self`) argument under `func.__name__`. Unfortunately, this is the
    most robust solution available; it would be preferable to perform this
    check when the decorator runs rather than on each call.

    Returns the `self` object if it's a method call, else None.

    Arguments:
        args: arguments to the function/method call.
        func: the unbound function that was called.
    """
    if len(args) > 0:
        method = getattr(args[0], func.__name__, False)
        if method:
            wrapped = getattr(method, "__wrapped__", False)
            if wrapped and wrapped == func:
                return args[0]

    return None
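
A small sketch of how this helper distinguishes bound methods from plain functions (illustrative only; it relies on `functools.wraps` setting `__wrapped__`, which is exactly what the check above inspects):

from functools import wraps
from ray.serve._private.utils import extract_self_if_method_call

def traced(func):
    @wraps(func)  # sets wrapper.__wrapped__ = func, which the helper inspects
    def wrapper(*args, **kwargs):
        self = extract_self_if_method_call(list(args), func)
        print(f"{func.__name__}: {'method' if self is not None else 'function'}")
        return func(*args, **kwargs)
    return wrapper

class Greeter:
    @traced
    def greet(self):
        return "hi"

@traced
def standalone():
    return "hi"

Greeter().greet()   # prints "greet: method"
standalone()        # prints "standalone: function"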
80 changes: 78 additions & 2 deletions python/ray/serve/api.py
@@ -2,13 +2,15 @@
import inspect
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union
from functools import wraps

from fastapi import APIRouter, FastAPI
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from starlette.requests import Request
from uvicorn.config import Config
from uvicorn.lifespan.on import LifespanOn

import ray
from ray import cloudpickle
from ray.dag import DAGNode
from ray.util.annotations import Deprecated, PublicAPI
@@ -29,6 +31,7 @@
    _set_global_client,
)
from ray.serve.deployment import Application, Deployment
from ray.serve.multiplex import _ModelMultiplexWrapper
from ray.serve._private.deployment_graph_build import build as pipeline_build
from ray.serve._private.deployment_graph_build import (
    get_and_validate_ingress_deployment,
@@ -45,6 +48,7 @@
    install_serve_encoders_to_fastapi,
    guarded_deprecation_warning,
    record_serve_tag,
    extract_self_if_method_call,
)

from ray.serve._private import api as _private_api
@@ -638,7 +642,78 @@ async def __call__(self, request):
            number if you want to save memory on the node resource.
    """

    raise NotImplementedError("Multiplexed deployment is not supported yet.")
    if func is not None:
        if not callable(func):
            raise TypeError(
                "The `multiplexed` decorator must be used with a function or method."
            )

        # TODO(Sihan): Make the API accept the sync function as well.
        # https://github.com/ray-project/ray/issues/35356
        if not inspect.iscoroutinefunction(func):
            raise TypeError(
                "@serve.multiplexed can only be used to decorate async "
                "functions or methods."
            )
        signature = inspect.signature(func)
        if len(signature.parameters) == 0 or len(signature.parameters) > 2:
            raise TypeError(
                "@serve.multiplexed can only be used to decorate functions or "
                "methods with exactly one 'model_id: str' argument (plus "
                "'self' for methods)."
            )

    if type(max_num_models_per_replica) is not int:
        raise TypeError("max_num_models_per_replica must be an integer.")

    if max_num_models_per_replica != -1 and max_num_models_per_replica <= 0:
        raise ValueError(
            "max_num_models_per_replica must be positive (or -1 for no limit)."
        )

    def _multiplex_decorator(func: Callable):
        @wraps(func)
        async def _multiplex_wrapper(*args):
            args_check_error_msg = (
                "Functions decorated with `@serve.multiplexed` must take exactly "
                "one argument, the multiplexed model ID (str), but got {}"
            )
            if not args:
                raise TypeError(args_check_error_msg.format("no arguments."))
            self = extract_self_if_method_call(args, func)

            # The user-defined multiplexed function can be a standalone function
            # or a method of a class. If it is a method, the first argument
            # is `self`.
            if self is None:
                if len(args) != 1:
                    raise TypeError(
                        args_check_error_msg.format("more than one argument.")
                    )
                multiplex_object = func
                model_id = args[0]
            else:
                # `self` counts as an argument here.
                if len(args) != 2:
                    raise TypeError(
                        args_check_error_msg.format("more than one argument.")
                    )
                multiplex_object = self
                model_id = args[1]
            multiplex_attr = f"__serve_multiplex_{func.__name__}"
            # If the multiplexed function is called for the first time,
            # create a model multiplex wrapper and cache it on the multiplex object.
            if not hasattr(multiplex_object, multiplex_attr):
                model_multiplex_wrapper = _ModelMultiplexWrapper(
                    func, self, max_num_models_per_replica
                )
                setattr(multiplex_object, multiplex_attr, model_multiplex_wrapper)
            else:
                model_multiplex_wrapper = getattr(multiplex_object, multiplex_attr)
            return await model_multiplex_wrapper.load_model(model_id)

        return _multiplex_wrapper

    return _multiplex_decorator(func) if callable(func) else _multiplex_decorator
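
The decorator also accepts a standalone async function, which exercises the `self is None` branch above; a sketch (the loader is hypothetical):

from ray import serve

@serve.multiplexed(max_num_models_per_replica=2)
async def get_model(model_id: str):
    # Standalone form: inside _multiplex_wrapper, `self` is None and the
    # _ModelMultiplexWrapper is cached on the function object itself under
    # the attribute "__serve_multiplex_get_model".
    return await my_load_model_fn(model_id)  # hypothetical loader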


@PublicAPI(stability="alpha")
@@ -667,4 +742,5 @@ def get_multiplexed_model_id() -> str:
        def my_deployment_function(request):
            assert serve.get_multiplexed_model_id() == "model_1"
    """
    raise NotImplementedError("get_multiplexed_model_id API is not supported yet.")
    _request_context = ray.serve.context._serve_request_context.get()
    return _request_context.multiplexed_model_id
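
For context, the field read here is populated per-request by the serving layer before user code runs (that code is outside this diff). A rough sketch of the contextvars mechanics being relied on, assuming RequestContext is a plain dataclass accepting keyword arguments:

from ray.serve.context import RequestContext, _serve_request_context

# Hypothetical: roughly what the runtime does around each request.
token = _serve_request_context.set(
    RequestContext(request_id="req-1", multiplexed_model_id="model_1")
)
try:
    # Within this scope, serve.get_multiplexed_model_id() returns "model_1".
    pass
finally:
    _serve_request_context.reset(token)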
27 changes: 2 additions & 25 deletions python/ray/serve/batching.py
@@ -9,6 +9,7 @@
from ray._private.signature import extract_signature, flatten_args, recover_args
from ray._private.utils import get_or_create_event_loop
from ray.serve.exceptions import RayServeException
from ray.serve._private.utils import extract_self_if_method_call
from ray.util.annotations import PublicAPI


@@ -171,30 +172,6 @@ def __del__(self):
        self._handle_batch_task.cancel()


def _extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
    """Check if this is a method rather than a function.
    Does this by checking to see if `func` is the attribute of the first
    (`self`) argument under `func.__name__`. Unfortunately, this is the most
    robust solution to this I was able to find. It would also be preferable
    to do this check when the decorator runs, rather than when the method is.
    Returns the `self` object if it's a method call, else None.
    Arguments:
        args (List[Any]): arguments to the function/method call.
        func: the unbound function that was called.
    """
    if len(args) > 0:
        method = getattr(args[0], func.__name__, False)
        if method:
            wrapped = getattr(method, "__wrapped__", False)
            if wrapped and wrapped == func:
                return args[0]

    return None


T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[[List[T]], List[R]])
@@ -289,7 +266,7 @@ async def __call__(self, request: Request):
    def _batch_decorator(_func):
        @wraps(_func)
        async def batch_wrapper(*args, **kwargs):
            self = _extract_self_if_method_call(args, _func)
            self = extract_self_if_method_call(args, _func)
            flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)

            if self is None:
1 change: 1 addition & 0 deletions python/ray/serve/context.py
@@ -149,6 +149,7 @@ class RequestContext:
    route: str = ""
    request_id: str = ""
    app_name: str = ""
    multiplexed_model_id: str = ""


_serve_request_context = contextvars.ContextVar(
96 changes: 96 additions & 0 deletions python/ray/serve/multiplex.py
@@ -0,0 +1,96 @@
import asyncio
import inspect
import logging
from collections import OrderedDict
from typing import Any, Callable

from ray._private.async_compat import sync_to_async
from ray.serve._private.constants import SERVE_LOGGER_NAME


logger = logging.getLogger(SERVE_LOGGER_NAME)


class _ModelMultiplexWrapper:
    """A wrapper that wraps the model load function and provides LRU caching.

    The model load function must be a coroutine that takes the model ID as its
    first argument and returns the user-constructed model object. The wrapper
    ensures that the number of models on the current replica does not exceed
    the configured limit: models are unloaded in LRU order, and an unloaded
    model's `__del__` attribute, if it exists, is called to clean up the
    model's resources eagerly.
    """

    def __init__(
        self,
        model_load_func: Callable[[str], Any],
        self_arg: Any,
        max_num_models_per_replica: int,
    ):
        """Initialize the model multiplexer.

        Args:
            model_load_func: the async model load function.
            self_arg: the `self` argument when model_load_func is a class method.
            max_num_models_per_replica: the maximum number of models to be loaded
                on the current replica. If it is -1, there is no limit on the
                number of models per replica.
        """
        self.models = OrderedDict()
        self._func = model_load_func
        self.self_arg = self_arg
        self.max_num_models_per_replica = max_num_models_per_replica

    async def load_model(self, model_id: str) -> Any:
        """Load the model if it is not loaded yet, and return the
        user-constructed model object.

        Args:
            model_id: the model ID.

        Returns:
            The user-constructed model object.
        """

        if type(model_id) != str:
            raise TypeError("The model ID must be a string.")

        if not model_id:
            raise ValueError("The model ID cannot be empty.")

        if model_id in self.models:
            # Move the model to the end of the OrderedDict to ensure LRU caching.
            model = self.models.pop(model_id)
            self.models[model_id] = model
        else:
            # If a limit on the number of models per replica is set, check
            # whether the current replica has reached it.
            if (
                self.max_num_models_per_replica > 0
                and len(self.models) >= self.max_num_models_per_replica
            ):
                # Unload the least recently used model.
                await self.unload_model()
            # Load the model.
            logger.info(f"Loading model '{model_id}'.")
            if self.self_arg is None:
                self.models[model_id] = await self._func(model_id)
            else:
                self.models[model_id] = await self._func(self.self_arg, model_id)
        return self.models[model_id]

    async def unload_model(self) -> None:
        """Unload the least recently used model."""
        model_id, model = self.models.popitem(last=False)
        logger.info(f"Unloading model '{model_id}'.")

        # If the model has a __del__ attribute, call it.
        # This is to clean up the model resources eagerly.
        if hasattr(model, "__del__"):
            if not inspect.iscoroutinefunction(model.__del__):
                await asyncio.get_running_loop().run_in_executor(None, model.__del__)
            else:
                await sync_to_async(model.__del__)()
            setattr(model, "__del__", lambda _: None)
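
A self-contained sketch of the LRU behavior this class implements (run outside Serve; the loader is a stand-in for a user-defined one):

import asyncio

async def load(model_id: str):
    return f"model-{model_id}"  # stand-in for a real model object

async def demo():
    wrapper = _ModelMultiplexWrapper(load, None, max_num_models_per_replica=2)
    await wrapper.load_model("a")   # cache: [a]
    await wrapper.load_model("b")   # cache: [a, b]
    await wrapper.load_model("a")   # "a" refreshed: [b, a]
    await wrapper.load_model("c")   # "b" is LRU and gets unloaded: [a, c]
    assert list(wrapper.models) == ["a", "c"]

asyncio.run(demo())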