diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index 456f57e8ae28..c6d4a6797092 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -9,7 +9,17 @@ import traceback from enum import Enum from functools import wraps -from typing import Dict, Iterable, List, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Tuple, + TypeVar, + Union, + Optional, +) import fastapi.encoders import numpy as np @@ -535,3 +545,27 @@ def record_serve_tag(key: str, value: str): ) record_extra_usage_tag(serve_telemetry_tag_map[key], value) + + +def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]: + """Check if this is a method rather than a function. + + Does this by checking to see if `func` is the attribute of the first + (`self`) argument under `func.__name__`. Unfortunately, this is the most + robust solution to this I was able to find. It would also be preferable + to do this check when the decorator runs, rather than when the method is. + + Returns the `self` object if it's a method call, else None. + + Arguments: + args: arguments to the function/method call. + func: the unbound function that was called. + """ + if len(args) > 0: + method = getattr(args[0], func.__name__, False) + if method: + wrapped = getattr(method, "__wrapped__", False) + if wrapped and wrapped == func: + return args[0] + + return None diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 737a62b26b75..c86d8b749f36 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -2,6 +2,7 @@ import inspect import logging from typing import Any, Callable, Dict, Optional, Tuple, Union +from functools import wraps from fastapi import APIRouter, FastAPI from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag @@ -9,6 +10,7 @@ from uvicorn.config import Config from uvicorn.lifespan.on import LifespanOn +import ray from ray import cloudpickle from ray.dag import DAGNode from ray.util.annotations import Deprecated, PublicAPI @@ -29,6 +31,7 @@ _set_global_client, ) from ray.serve.deployment import Application, Deployment +from ray.serve.multiplex import _ModelMultiplexWrapper from ray.serve._private.deployment_graph_build import build as pipeline_build from ray.serve._private.deployment_graph_build import ( get_and_validate_ingress_deployment, @@ -45,6 +48,7 @@ install_serve_encoders_to_fastapi, guarded_deprecation_warning, record_serve_tag, + extract_self_if_method_call, ) from ray.serve._private import api as _private_api @@ -638,7 +642,78 @@ async def __call__(self, request): number if you want to save memory on the node resource. """ - raise NotImplementedError("Multiplexed deployment is not supported yet.") + if func is not None: + if not callable(func): + raise TypeError( + "The `multiplexed` decorator must be used with a function or method." + ) + + # TODO(Sihan): Make the API accept the sync function as well. + # https://github.com/ray-project/ray/issues/35356 + if not inspect.iscoroutinefunction(func): + raise TypeError( + "@serve.multiplexed can only be used to decorate async " + "functions or methods." + ) + signature = inspect.signature(func) + if len(signature.parameters) == 0 or len(signature.parameters) > 2: + raise TypeError( + "@serve.multiplexed can only be used to decorate functions or methods " + "with at least one 'model_id: str' argument." 
+            )
+
+    if type(max_num_models_per_replica) is not int:
+        raise TypeError("max_num_models_per_replica must be an integer.")
+
+    if max_num_models_per_replica != -1 and max_num_models_per_replica <= 0:
+        raise ValueError("max_num_models_per_replica must be positive.")
+
+    def _multiplex_decorator(func: Callable):
+        @wraps(func)
+        async def _multiplex_wrapper(*args):
+            args_check_error_msg = (
+                "Functions decorated with `@serve.multiplexed` must take exactly one "
+                "argument, the multiplexed model ID (str), but got {}"
+            )
+            if not args:
+                raise TypeError(
+                    args_check_error_msg.format("no arguments.")
+                )
+            self = extract_self_if_method_call(args, func)
+
+            # A user-defined multiplexed function can be a standalone function or a
+            # method of a class. If it is a method of a class, the first argument
+            # is self.
+            if self is None:
+                if len(args) != 1:
+                    raise TypeError(
+                        args_check_error_msg.format("more than one argument.")
+                    )
+                multiplex_object = func
+                model_id = args[0]
+            else:
+                # Count `self` as an argument.
+                if len(args) != 2:
+                    raise TypeError(
+                        args_check_error_msg.format("more than one argument.")
+                    )
+                multiplex_object = self
+                model_id = args[1]
+            multiplex_attr = f"__serve_multiplex_{func.__name__}"
+            # If the multiplexed function is called for the first time,
+            # create a model multiplex wrapper and cache it in the multiplex object.
+            if not hasattr(multiplex_object, multiplex_attr):
+                model_multiplex_wrapper = _ModelMultiplexWrapper(
+                    func, self, max_num_models_per_replica
+                )
+                setattr(multiplex_object, multiplex_attr, model_multiplex_wrapper)
+            else:
+                model_multiplex_wrapper = getattr(multiplex_object, multiplex_attr)
+            return await model_multiplex_wrapper.load_model(model_id)
+
+        return _multiplex_wrapper
+
+    return _multiplex_decorator(func) if callable(func) else _multiplex_decorator
 
 
 @PublicAPI(stability="alpha")
@@ -667,4 +742,5 @@ def get_multiplexed_model_id() -> str:
         def my_deployment_function(request):
             assert serve.get_multiplexed_model_id() == "model_1"
     """
-    raise NotImplementedError("get_multiplexed_model_id API is not supported yet.")
+    _request_context = ray.serve.context._serve_request_context.get()
+    return _request_context.multiplexed_model_id
diff --git a/python/ray/serve/batching.py b/python/ray/serve/batching.py
index a44f918bc079..219e954a6831 100644
--- a/python/ray/serve/batching.py
+++ b/python/ray/serve/batching.py
@@ -9,6 +9,7 @@ from ray._private.signature import extract_signature, flatten_args, recover_args
 from ray._private.utils import get_or_create_event_loop
 from ray.serve.exceptions import RayServeException
+from ray.serve._private.utils import extract_self_if_method_call
 from ray.util.annotations import PublicAPI
 
 
@@ -171,30 +172,6 @@ def __del__(self):
             self._handle_batch_task.cancel()
 
 
-def _extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
-    """Check if this is a method rather than a function.
-
-    Does this by checking to see if `func` is the attribute of the first
-    (`self`) argument under `func.__name__`. Unfortunately, this is the most
-    robust solution to this I was able to find. It would also be preferable
-    to do this check when the decorator runs, rather than when the method is.
-
-    Returns the `self` object if it's a method call, else None.
-
-    Arguments:
-        args (List[Any]): arguments to the function/method call.
-        func: the unbound function that was called.
-    """
-    if len(args) > 0:
-        method = getattr(args[0], func.__name__, False)
-        if method:
-            wrapped = getattr(method, "__wrapped__", False)
-            if wrapped and wrapped == func:
-                return args[0]
-
-    return None
-
-
 T = TypeVar("T")
 R = TypeVar("R")
 F = TypeVar("F", bound=Callable[[List[T]], List[R]])
@@ -289,7 +266,7 @@ async def __call__(self, request: Request):
     def _batch_decorator(_func):
         @wraps(_func)
         async def batch_wrapper(*args, **kwargs):
-            self = _extract_self_if_method_call(args, _func)
+            self = extract_self_if_method_call(args, _func)
             flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
 
             if self is None:
diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py
index 902025b0adcc..0823ceac2d88 100644
--- a/python/ray/serve/context.py
+++ b/python/ray/serve/context.py
@@ -149,6 +149,7 @@ class RequestContext:
     route: str = ""
     request_id: str = ""
     app_name: str = ""
+    multiplexed_model_id: str = ""
 
 
 _serve_request_context = contextvars.ContextVar(
diff --git a/python/ray/serve/multiplex.py b/python/ray/serve/multiplex.py
new file mode 100644
index 000000000000..0cfed15bf62d
--- /dev/null
+++ b/python/ray/serve/multiplex.py
@@ -0,0 +1,96 @@
+from ray._private.async_compat import sync_to_async
+from collections import OrderedDict
+from typing import Any, Callable
+import logging
+from ray.serve._private.constants import SERVE_LOGGER_NAME
+import inspect
+import asyncio
+
+
+logger = logging.getLogger(SERVE_LOGGER_NAME)
+
+
+class _ModelMultiplexWrapper:
+    """A wrapper that adds LRU caching around a user-provided model load function.
+
+    The model load function must be a coroutine function that takes the model ID
+    as its first argument and returns the user-constructed model object.
+    The wrapper ensures that the number of models loaded on the current replica
+    does not exceed the specified limit; once the limit is reached, models are
+    unloaded in LRU order. When a model is unloaded, its __del__ attribute is
+    called, if it exists, to clean up the model's resources eagerly.
+    """
+
+    def __init__(
+        self,
+        model_load_func: Callable[[str], Any],
+        self_arg: Any,
+        max_num_models_per_replica: int,
+    ):
+        """Initialize the model multiplexer.
+
+        Args:
+            model_load_func: the async model load function.
+            self_arg: the `self` argument, when model_load_func is a class method.
+            max_num_models_per_replica: the maximum number of models to be loaded
+                on the current replica. If it is -1, there is no limit on the
+                number of models per replica.
+        """
+        self.models = OrderedDict()
+        self._func = model_load_func
+        self.self_arg = self_arg
+        self.max_num_models_per_replica = max_num_models_per_replica
+
+    async def load_model(self, model_id: str) -> Any:
+        """Load the model if it is not loaded yet, and return it.
+
+        Args:
+            model_id: the model ID.
+
+        Returns:
+            The user-constructed model object.
+        """
+
+        if type(model_id) != str:
+            raise TypeError("The model ID must be a string.")
+
+        if not model_id:
+            raise ValueError("The model ID cannot be empty.")
+
+        if model_id in self.models:
+            # Move the model to the end of the OrderedDict to ensure LRU caching.
+            model = self.models.pop(model_id)
+            self.models[model_id] = model
+        else:
+            # If the number of models per replica is limited, check whether the
+            # limit has been reached on the current replica.
+            if (
+                self.max_num_models_per_replica > 0
+                and len(self.models) >= self.max_num_models_per_replica
+            ):
+                # Unload the least recently used model.
+                await self.unload_model()
+            # Load the model.
+            logger.info(f"Loading model '{model_id}'.")
+            if self.self_arg is None:
+                self.models[model_id] = await self._func(model_id)
+            else:
+                self.models[model_id] = await self._func(self.self_arg, model_id)
+        return self.models[model_id]
+
+    async def unload_model(self) -> None:
+        """Unload the least recently used model."""
+        model_id, model = self.models.popitem(last=False)
+        logger.info(f"Unloading model '{model_id}'.")
+
+        # If the model has a __del__ attribute, call it to clean up the
+        # model's resources eagerly.
+        if hasattr(model, "__del__"):
+            if not inspect.iscoroutinefunction(model.__del__):
+                await asyncio.get_running_loop().run_in_executor(None, model.__del__)
+            else:
+                await sync_to_async(model.__del__)()
+            setattr(model, "__del__", lambda _: None)
diff --git a/python/ray/serve/tests/test_multiplex.py b/python/ray/serve/tests/test_multiplex.py
index c1eb44d2c066..c435c1523925 100644
--- a/python/ray/serve/tests/test_multiplex.py
+++ b/python/ray/serve/tests/test_multiplex.py
@@ -1,25 +1,160 @@
 import pytest
+import ray
 from ray import serve
+from ray.serve.multiplex import _ModelMultiplexWrapper
 
 
-def test_multiplexed():
-    """Test multiplexed API."""
+class TestMultiplexWrapper:
+    @pytest.mark.asyncio
+    async def test_multiplex_wrapper(self):
+        """Test multiplex wrapper with LRU caching."""
 
-    with pytest.raises(NotImplementedError):
+        async def model_load_func(model_id: str):
+            return model_id
+
+        multiplexer = _ModelMultiplexWrapper(
+            model_load_func, None, max_num_models_per_replica=2
+        )
+        # Load model1.
+        await multiplexer.load_model("1")
+        assert multiplexer.models == {"1": "1"}
+        # Load model2.
+        await multiplexer.load_model("2")
+        assert multiplexer.models == {"1": "1", "2": "2"}
+
+        # Load model3; model1 should be unloaded.
+        await multiplexer.load_model("3")
+        assert multiplexer.models == {"2": "2", "3": "3"}
+
+        # Reload model2; it should be moved to the end of the LRU cache.
+        await multiplexer.load_model("2")
+        assert multiplexer.models == {"3": "3", "2": "2"}
+
+        # Load model4; model3 should be unloaded.
+        await multiplexer.load_model("4")
+        assert multiplexer.models == {"2": "2", "4": "4"}
+
+    @pytest.mark.asyncio
+    async def test_bad_call_multiplexed_func(self):
+        """Test bad calls to the multiplexed function."""
+
+        async def model_load_func(model_id: str):
+            return model_id
+
+        multiplexer = _ModelMultiplexWrapper(
+            model_load_func, None, max_num_models_per_replica=2
+        )
+        with pytest.raises(TypeError):
+            await multiplexer.load_model(1)
+        with pytest.raises(TypeError):
+            await multiplexer.load_model()
+
+    @pytest.mark.asyncio
+    async def test_unload_model_call_del(self):
+        class MyModel:
+            def __init__(self, model_id):
+                self.model_id = model_id
+
+            def __del__(self):
+                raise Exception(f"{self.model_id} is dead")
+
+            def __eq__(self, model):
+                return model.model_id == self.model_id
+
+        async def model_load_func(model_id: str) -> MyModel:
+            return MyModel(model_id)
+
+        multiplexer = _ModelMultiplexWrapper(
+            model_load_func, None, max_num_models_per_replica=1
+        )
+        await multiplexer.load_model("1")
+        assert multiplexer.models == {"1": MyModel("1")}
+        with pytest.raises(Exception, match="1 is dead"):
+            await multiplexer.load_model("2")
+
+
+class TestBasicAPI:
+    def test_decorator_validation(self):
+        @serve.multiplexed
+        async def get_model(model: str):
+            return
+
+        @serve.multiplexed(max_num_models_per_replica=1)
+        async def get_model2(model: str):
+            return
+
+        @serve.deployment
+        class MyModel:
+            @serve.multiplexed
+            async def get_model(model: str):
+                return
 
         @serve.deployment
-        class Model:
+        class MyModel2:
+            @serve.multiplexed(max_num_models_per_replica=1)
+            async def get_model(self, model: str):
+                return
+
+        # The multiplexed decorator can only be used on a function or method.
+        with pytest.raises(TypeError):
+
+            @serve.deployment
+            @serve.multiplexed
+            class BadDecorator:
+                pass
+
+        # max_num_models_per_replica must be an integer.
+        with pytest.raises(TypeError):
+
+            @serve.multiplexed(max_num_models_per_replica="1")
+            async def get_model3(model: str):
+                pass
+
+        # max_num_models_per_replica must be positive.
+        with pytest.raises(ValueError):
+
+            @serve.multiplexed(max_num_models_per_replica=0)
+            async def get_model4(model: str):
+                pass
+
+        # The multiplexed function must be async.
+        with pytest.raises(TypeError):
+
+            @serve.multiplexed
+            def get_model5(model: str):
+                pass
+
+        with pytest.raises(TypeError):
+
+            @serve.deployment
+            class MyModel3:
+                @serve.multiplexed
+                def get_model(self, model: str):
+                    return
+
+        # No model_id argument in the multiplexed function.
+        with pytest.raises(TypeError):
+
             @serve.multiplexed
-            def get_model(self, model_id: str):
+            def get_model6():
                 pass
 
+        with pytest.raises(TypeError):
 
-def test_get_multiplexed_model_id():
-    """Test get_multiplexed_model_id API."""
+            @serve.deployment
+            class MyModel4:
+                @serve.multiplexed
+                def get_model(self):
+                    return
 
-    with pytest.raises(NotImplementedError):
-        serve.get_multiplexed_model_id()
+    def test_get_multiplexed_model_id(self):
+        """Test the get_multiplexed_model_id() API."""
+        assert serve.get_multiplexed_model_id() == ""
+        ray.serve.context._serve_request_context.set(
+            ray.serve.context.RequestContext(multiplexed_model_id="1")
+        )
+        assert serve.get_multiplexed_model_id() == "1"
 
 
 if __name__ == "__main__":