Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Multiplex API Impl #35326

Merged
merged 5 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion python/ray/serve/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
import traceback
from enum import Enum
from functools import wraps
from typing import Dict, Iterable, List, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Tuple,
TypeVar,
Union,
Optional,
)

import fastapi.encoders
import numpy as np
Expand Down Expand Up @@ -535,3 +545,27 @@ def record_serve_tag(key: str, value: str):
)

record_extra_usage_tag(serve_telemetry_tag_map[key], value)


def _extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
sihanwang41 marked this conversation as resolved.
Show resolved Hide resolved
"""Check if this is a method rather than a function.

Does this by checking to see if `func` is the attribute of the first
(`self`) argument under `func.__name__`. Unfortunately, this is the most
robust solution to this I was able to find. It would also be preferable
to do this check when the decorator runs, rather than when the method is.

Returns the `self` object if it's a method call, else None.

Arguments:
args (List[Any]): arguments to the function/method call.
func: the unbound function that was called.
sihanwang41 marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(args) > 0:
method = getattr(args[0], func.__name__, False)
if method:
wrapped = getattr(method, "__wrapped__", False)
if wrapped and wrapped == func:
return args[0]

return None
72 changes: 70 additions & 2 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import inspect
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union
from functools import wraps

from fastapi import APIRouter, FastAPI
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from starlette.requests import Request
from uvicorn.config import Config
from uvicorn.lifespan.on import LifespanOn

import ray
from ray import cloudpickle
from ray.dag import DAGNode
from ray.util.annotations import Deprecated, PublicAPI
Expand All @@ -29,6 +31,7 @@
_set_global_client,
)
from ray.serve.deployment import Application, Deployment
from ray.serve.multiplex import _ModelMultiplexWrapper
from ray.serve._private.deployment_graph_build import build as pipeline_build
from ray.serve._private.deployment_graph_build import (
get_and_validate_ingress_deployment,
Expand All @@ -45,6 +48,7 @@
install_serve_encoders_to_fastapi,
guarded_deprecation_warning,
record_serve_tag,
_extract_self_if_method_call,
)

from ray.serve._private import api as _private_api
Expand Down Expand Up @@ -638,7 +642,70 @@ async def __call__(self, request):
number if you want to save memory on the node resource.
"""

raise NotImplementedError("Multiplexed deployment is not supported yet.")
if func is not None:
if not callable(func):
raise TypeError(
"The `multiplexed` decorator must be used with a function or method."
)

if not inspect.iscoroutinefunction(func):
raise TypeError(
"@serve.multiplexed can only be used to decorate async "
"functions or methods."
)
Comment on lines +653 to +657
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should remove this requirement and post sync callables to an executor (can be a follow-up PR)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would like to do the same for batch handler

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the user uses a sync function:

@serve.multiplexed
def load_model(mode_id:str):
    return

@serve.deployment
class Model:
    async def __call__(self, req):
        model = await load_model(req.meta.model_id)

The user defines a sync function, but still has to use `await` in the deployment code. This might cause confusion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I agree it is a little bit of a funny API... but probably better than the alternative where people will just slap async in front of their function and actually block the loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave it to as a followup :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok please leave TODO + link to a github issue tracking it

signature = inspect.signature(func)
if len(signature.parameters) == 0 or len(signature.parameters) > 2:
raise TypeError(
"@serve.multiplexed can only be used to decorate functions or methods "
"with at least one 'model_id: str' argument."
)
edoakes marked this conversation as resolved.
Show resolved Hide resolved

if type(max_num_models_per_replica) is not int:
raise TypeError("max_num_models_per_replica must be an integer.")

if max_num_models_per_replica != -1 and max_num_models_per_replica <= 0:
raise ValueError("max_num_models_per_replica must be positive.")

def _multiplex_decorator(func):
    """Wrap `func` so each call routes through a per-object LRU model cache.

    The returned async wrapper lazily attaches a `_ModelMultiplexWrapper`
    to the decorated function (for free functions) or to the `self`
    instance (for methods) and delegates model loading to it.
    """

    @wraps(func)
    async def _multiplex_wrapper(*args):
        args_check_error_msg = (
            "The args of the multiplexed function must have at least one "
            "argument with type `str` as model_id, but got {}"
        )
        if not args:
            raise TypeError(args_check_error_msg.format("no arguments"))
        self = _extract_self_if_method_call(args, func)
        if self is None:
            # Plain function: the single positional argument is the model ID.
            if len(args) != 1:
                raise TypeError(
                    args_check_error_msg.format(f"{len(args)} arguments")
                )
            multiplex_object = func
            model_id = args[0]
        else:
            # Method: `self` counts as the first positional argument.
            if len(args) != 2:
                raise TypeError(
                    args_check_error_msg.format(f"{len(args) - 1} arguments")
                )
            multiplex_object = self
            model_id = args[1]
        # Cache the multiplex wrapper on the target object so its state
        # (the loaded-model LRU) survives across calls.
        multiplex_attr = f"__serve_multiplex_{func.__name__}"
        if not hasattr(multiplex_object, multiplex_attr):
            model_multiplex_wrapper = _ModelMultiplexWrapper(
                func, self, max_num_models_per_replica
            )
            setattr(multiplex_object, multiplex_attr, model_multiplex_wrapper)
        else:
            model_multiplex_wrapper = getattr(multiplex_object, multiplex_attr)
        return await model_multiplex_wrapper.load_model(model_id)

    return _multiplex_wrapper

return _multiplex_decorator(func) if callable(func) else _multiplex_decorator


@PublicAPI(stability="alpha")
Expand Down Expand Up @@ -667,4 +734,5 @@ def get_multiplexed_model_id() -> str:
def my_deployment_function(request):
assert serve.get_multiplexed_model_id() == "model_1"
"""
raise NotImplementedError("get_multiplexed_model_id API is not supported yet.")
_request_context = ray.serve.context._serve_request_context.get()
return _request_context.multiplexed_model_id
25 changes: 1 addition & 24 deletions python/ray/serve/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ray._private.signature import extract_signature, flatten_args, recover_args
from ray._private.utils import get_or_create_event_loop
from ray.serve.exceptions import RayServeException
from ray.serve._private.utils import _extract_self_if_method_call
from ray.util.annotations import PublicAPI


Expand Down Expand Up @@ -171,30 +172,6 @@ def __del__(self):
self._handle_batch_task.cancel()


def _extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
"""Check if this is a method rather than a function.

Does this by checking to see if `func` is the attribute of the first
(`self`) argument under `func.__name__`. Unfortunately, this is the most
robust solution to this I was able to find. It would also be preferable
to do this check when the decorator runs, rather than when the method is.

Returns the `self` object if it's a method call, else None.

Arguments:
args (List[Any]): arguments to the function/method call.
func: the unbound function that was called.
"""
if len(args) > 0:
method = getattr(args[0], func.__name__, False)
if method:
wrapped = getattr(method, "__wrapped__", False)
if wrapped and wrapped == func:
return args[0]

return None


T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[[List[T]], List[R]])
Expand Down
1 change: 1 addition & 0 deletions python/ray/serve/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class RequestContext:
route: str = ""
request_id: str = ""
app_name: str = ""
multiplexed_model_id: str = ""


_serve_request_context = contextvars.ContextVar(
Expand Down
90 changes: 90 additions & 0 deletions python/ray/serve/multiplex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from ray._private.async_compat import sync_to_async
from collections import OrderedDict
from typing import Any, Callable
import logging
from ray.serve._private.constants import SERVE_LOGGER_NAME
import inspect
import asyncio


logger = logging.getLogger(SERVE_LOGGER_NAME)


class _ModelMultiplexWrapper:
    """LRU-cache wrapper around a user-provided async model load function.

    Wraps `model_load_func(model_id) -> model` and keeps the loaded models
    in an LRU-ordered cache, evicting the least recently used model when
    `max_num_models_per_replica` would be exceeded. On eviction, the
    model's `__del__` attribute (if present) is invoked to release its
    resources eagerly.
    """

    def __init__(
        self,
        model_load_func: Callable[[str], Any],
        self_arg: Any,
        max_num_models_per_replica: int,
    ):
        """Initialize the model multiplexer.

        Args:
            model_load_func: the async model load function. It takes the
                model ID as its (last) argument and returns the model handle.
            self_arg: the `self` argument to prepend when `model_load_func`
                is a method; None when it is a free function.
            max_num_models_per_replica: the maximum number of models to keep
                loaded on the current replica. -1 means no limit.
        """
        # OrderedDict provides LRU ordering: most recently used at the end.
        self.models = OrderedDict()
        self._func = model_load_func
        self.self_arg = self_arg
        self.max_num_models_per_replica = max_num_models_per_replica

    async def load_model(self, model_id: str) -> Any:
        """Load the model if it is not loaded yet and return its handle.

        Args:
            model_id: the model ID; must be a non-empty string.

        Returns:
            The model handle returned by the load function.

        Raises:
            TypeError: if `model_id` is not a string.
            ValueError: if `model_id` is the empty string.
        """
        if not isinstance(model_id, str):
            raise TypeError("The model ID must be a string.")

        if not model_id:
            raise ValueError("The model ID cannot be empty.")

        if model_id in self.models:
            # Refresh the LRU position: most recently used goes to the end.
            self.models.move_to_end(model_id)
        else:
            # If a per-replica limit is set, evict LRU entries until there
            # is room for the new model.
            # TODO(review): concurrent calls for the same model_id can load
            # it twice; consider per-id locking as a follow-up.
            if self.max_num_models_per_replica > 0:
                while len(self.models) >= self.max_num_models_per_replica:
                    await self.unload_model()
            # Load the model.
            logger.info("Loading model '%s'.", model_id)
            if self.self_arg is None:
                self.models[model_id] = await self._func(model_id)
            else:
                self.models[model_id] = await self._func(self.self_arg, model_id)
        return self.models[model_id]

    async def unload_model(self) -> None:
        """Unload the least recently used model, releasing resources eagerly."""
        model_id, model = self.models.popitem(last=False)
        logger.info("Unloading model '%s'.", model_id)

        # If the model has a __del__ attribute, call it eagerly to clean up
        # the model's resources.
        finalizer = getattr(model, "__del__", None)
        if finalizer is not None:
            if inspect.iscoroutinefunction(finalizer):
                # An async finalizer can simply be awaited (the previous
                # code wrapped it in sync_to_async, which is redundant for
                # a coroutine function).
                await finalizer()
            else:
                # Run a potentially blocking sync finalizer off the event
                # loop so it cannot stall other requests.
                await asyncio.get_running_loop().run_in_executor(None, finalizer)
            # Best-effort guard against running cleanup a second time at
            # garbage collection. NOTE(review): CPython looks `__del__` up
            # on the *type*, so an instance attribute does not reliably
            # suppress the class-level finalizer — confirm double-cleanup
            # is safe, or track a "closed" flag on the model instead.
            try:
                setattr(model, "__del__", lambda: None)
            except AttributeError:
                pass
Loading