[Manager] Optimize watch instance deployment states implementation & Add enable_port_offset_store arg #92

Merged
merged 6 commits into from
Jan 15, 2025
Changes from 5 commits
4 changes: 4 additions & 0 deletions docs/Arguments.md
@@ -53,6 +53,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--enable-pd-disagg]
[--num-dispatch-instances NUM_DISPATCH_INSTANCES]
[--enable-port-increment]
[--enable-port-offset-store]
```

`--host`
@@ -237,6 +238,9 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
`--enable-port-increment`
- Enable port increment when deploying multiple servers.

`--enable-port-offset-store`
- Enable storing the port offset when deploying multiple servers.

# Unsupported vLLM feature options

`--device`
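The two new behaviours pair up: `--enable-port-increment` bumps the API server and request-output-queue ports for each additional server the manager launches, while `--enable-port-offset-store` persists the running offset (via Ray's internal KV in the manager) so a restarted manager keeps counting instead of reusing ports. Below is a minimal, self-contained sketch of that interplay; `MiniManager`, `init_server`, and the dict-backed `kv_store` are illustrative stand-ins, not llumnix APIs.

```python
# Illustrative sketch only: a plain dict stands in for Ray's internal KV store.
kv_store = {}

class MiniManager:
    """Mimics the port handling this PR adds to Manager.__init__ and _init_server."""

    def __init__(self, enable_port_increment: bool, enable_port_offset_store: bool):
        self.enable_port_increment = enable_port_increment
        self.enable_port_offset_store = enable_port_offset_store
        if enable_port_increment:
            self.port_offset = 0
            if enable_port_offset_store:
                # Restore the last persisted offset, as the manager now does on start.
                self.port_offset = int(kv_store.get("manager.port_offset", 0))

    def init_server(self, base_port: int) -> int:
        port = base_port
        if self.enable_port_increment:
            port += self.port_offset
            self.port_offset += 1
            if self.enable_port_offset_store:
                kv_store["manager.port_offset"] = self.port_offset
        return port

m1 = MiniManager(enable_port_increment=True, enable_port_offset_store=True)
print(m1.init_server(8000), m1.init_server(8000))  # 8000 8001

# A "restarted" manager picks the offset back up from the store, so the next
# server does not collide with ports handed out before the restart.
m2 = MiniManager(enable_port_increment=True, enable_port_offset_store=True)
print(m2.init_server(8000))  # 8002
```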
8 changes: 8 additions & 0 deletions llumnix/arg_utils.py
@@ -153,6 +153,8 @@ class ManagerArgs:
num_dispatch_instances: int = None

enable_port_increment: bool = None
enable_port_offset_store: bool = None


def __post_init__(self):
# Check if all fields default to None
@@ -222,6 +224,9 @@ def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser):
assert not args.simulator_mode or args.profiling_result_file_path is not None, \
"Set profiling_result_file_path args when enable simulator mode"

assert not args.enable_port_offset_store or args.enable_port_increment, \
"Set enable_port_increment when enable_port_offset_store"

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--initial-instances',
@@ -357,6 +362,9 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--enable-port-increment',
action='store_true',
help='enable port increment when deploying multiple servers')
parser.add_argument('--enable-port-offset-store',
action='store_true',
help='enable storing the port offset when deploying multiple servers')

return parser

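For reference, here is how the two flags behave in isolation. This standalone argparse sketch mirrors the options added in `add_cli_args` and the dependency enforced by the new `check_args` assertion, but it is not the llumnix parser itself.

```python
import argparse

# Standalone mirror of the two options added to ManagerArgs.add_cli_args.
parser = argparse.ArgumentParser()
parser.add_argument('--enable-port-increment', action='store_true',
                    help='enable port increment when deploying multiple servers')
parser.add_argument('--enable-port-offset-store', action='store_true',
                    help='enable storing the port offset when deploying multiple servers')

args = parser.parse_args(['--enable-port-increment', '--enable-port-offset-store'])

# Same dependency the new check_args assertion enforces:
# storing a port offset only makes sense when port increment is enabled.
assert not args.enable_port_offset_store or args.enable_port_increment, \
    "enable_port_offset_store requires enable_port_increment"
print(args.enable_port_increment, args.enable_port_offset_store)  # True True
```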
2 changes: 2 additions & 0 deletions llumnix/config/default.py
@@ -75,6 +75,8 @@
_C.MANAGER.PROFILING_RESULT_FILE_PATH = None
# Enable port increment when deploying multiple servers
_C.MANAGER.ENABLE_PORT_INCREMENT = False
# Enable store port offset when deploying multiple servers
_C.MANAGER.ENABLE_PORT_OFFSET_STORE = False

# -----------------------------------------------------------------------------
# DISPATCH CONFIGURATION
1 change: 0 additions & 1 deletion llumnix/llumlet/llumlet.py
@@ -194,7 +194,6 @@ def get_all_request_ids(self) -> List[str]:
return self.backend_engine.get_all_request_ids()

def generate(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None:
# This should not be used for logging, as it is monotonic time.
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.llumlet_generate_timestamp = time.time()
self.backend_engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)
58 changes: 39 additions & 19 deletions llumnix/manager.py
@@ -52,7 +52,8 @@
AUTO_SCALE_UP_INTERVAL = 1.0
WAIT_PLACEMENT_GROUP_TIMEOUT = 5.0
CHECK_DEPLOYMENT_STATES_INTERVAL = 30.0
WATCH_DEPLOYMENT_INTERVAL = 40.0
WATCH_DEPLOYMENT_INTERVAL = 10.0
WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE = 120.0

# TODO(s5u13b): Handle exception of ray operations.
# TODO(s5u13b): Add exception handling wrapper.
@@ -132,8 +133,12 @@ def __init__(self,
asyncio.create_task(self._update_instance_info_loop(self.polling_interval))
asyncio.create_task(self._clear_request_instance_loop(CLEAR_REQUEST_INSTANCE_INTERVAL))

value = get_actor_data_from_ray_internal_kv("manager", "port_offset")
self.port_offset = 0 if value is None else int(value)
if self.manager_args.enable_port_increment:
self.port_offset = 0
if self.manager_args.enable_port_offset_store:
value = get_actor_data_from_ray_internal_kv("manager", "port_offset")
if value is not None:
self.port_offset = int(value)
if hasattr(self, "launch_mode") and self.launch_mode == LaunchMode.GLOBAL:
assert self.entrypoints_args is not None and self.engine_args is not None
self.last_timeout_instance_id = None
@@ -312,7 +317,7 @@ async def _auto_scale_up_loop(self, interval: float) -> None:
try:
await asyncio.wait_for(new_pg.ready(), WAIT_PLACEMENT_GROUP_TIMEOUT)
except asyncio.TimeoutError:
logger.info("[_auto_scale_up_loop] waiting for new placement group ready timeout")
logger.debug("[_auto_scale_up_loop] waiting for new placement group ready timeout")
# After timeout, the new placement group might be pending,
# created(without server and instance), rescheduling.
self.last_timeout_instance_id = new_instance_id
@@ -430,7 +435,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b
no_pending_instance = self.pending_rebuild_migration_instances == 0

for ins_id in instance_ids:
self._clear_instance_ray_resources(ins_id)
self._clear_instance_ray_states(ins_id)
if ins_id in self.instances:
indeed_update = True
if ins_id in self.instances:
@@ -460,7 +465,7 @@

return self.num_instances

def _clear_instance_ray_resources(self, instance_id: str):
def _clear_instance_ray_states(self, instance_id: str):
if not remove_placement_group(instance_id):
logger.debug("[clear_instance_ray_resources] failed to remove placement group {}".format(instance_id))
if not kill_server(instance_id):
@@ -544,7 +549,8 @@ def _init_server(self,
entrypoints_args.port += self.port_offset
entrypoints_args.request_output_queue_port += self.port_offset
self.port_offset += 1
put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset)
if self.manager_args.enable_port_offset_store:
put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset)
fastapi_server = FastAPIServerActor.from_args(server_name, placement_group, entrypoints_args)
return fastapi_server

@@ -605,23 +611,27 @@ async def done_scale_up():
asyncio.create_task(done_scale_up())

async def _check_deployment_states_loop(self, interval: float) -> None:
async def watch_deployment(instance_id: str):
async def watch_instance_deployment_states(instance_id: str):
# There might be some delays of calling _init_server_and_instance, so sleep first.
await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL)
curr_pgs, curr_servers, curr_instances = self.get_curr_deployment()
if instance_id in curr_pgs and (instance_id not in curr_servers or instance_id not in curr_instances):
logger.warning("[_check_deployment_states_loop] instance {} deployment states incorrect, "
"states: (pg {}, server {}, instance {})"
.format(instance_id, instance_id in curr_pgs, instance_id in curr_servers, instance_id in curr_instances))
instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))])
instance_pending_creation = len(instance_state) == 1 and instance_state[0]["state"] == "PENDING_CREATION"
if instance_pending_creation:
await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE)
pg_created, server_alive, instance_alive = self._get_instance_deployment_states(instance_id)
if pg_created and (not server_alive or not instance_alive):
logger.warning("instance {} deployment states incorrect, states: (pg {}, server {}, instance {})"
.format(instance_id, pg_created, server_alive, instance_alive))
self.scale_down(instance_id)

while True:
try:
curr_pgs, curr_servers, curr_instances = self.get_curr_deployment()
curr_pgs, curr_servers, curr_instances = self._get_cluster_deployment()
assert len(curr_pgs) >= max(len(curr_servers), len(curr_instances))
tasks = []
for instance_id in curr_pgs:
if instance_id not in curr_servers or instance_id not in curr_instances:
tasks.append(asyncio.create_task(watch_deployment(instance_id)))
tasks.append(asyncio.create_task(watch_instance_deployment_states(instance_id)))
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.sleep(interval)
# pylint: disable=broad-except
@@ -655,7 +665,7 @@ def check_instance_error_done_callback(idx: int, instance_id: str, fut):

return results

def get_curr_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, FastAPIServerActor], Dict[str, Llumlet]]:
def _get_cluster_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, FastAPIServerActor], Dict[str, Llumlet]]:
curr_pgs: Dict[str, PlacementGroup] = {}
curr_servers: Dict[str, PlacementGroup] = {}
curr_instances: Dict[str, Llumlet] = {}
@@ -676,6 +686,16 @@ def get_curr_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, Fast

return curr_pgs, curr_servers, curr_instances

def _get_instance_deployment_states(self, instance_id: str):
pg_state = list_placement_groups(filters=[("name", "=", get_placement_group_name(instance_id))])
pg_created = len(pg_state) == 1 and pg_state[0]["state"] == "CREATED"
server_state = list_actors(filters=[("name", "=", get_server_name(instance_id))])
server_alive = len(server_state) == 1 and server_state[0]["state"] == "ALIVE"
instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))])
instance_alive = len(instance_state) == 1 and instance_state[0]["state"] == "ALIVE"

return pg_created, server_alive, instance_alive

async def _get_request_instance(self) -> None:
def get_request_instance_done_callback(instance_id: str, fut):
ret = fut.result()[0]
@@ -690,12 +710,12 @@ def get_request_instance_done_callback(instance_id: str, fut):
instance_ids = []
tasks = []
for instance_id, instance_actor_handle in self.instances.items():
task = asyncio.gather(instance_actor_handle.get_instance_info.remote(), return_exceptions=True)
task = asyncio.gather(instance_actor_handle.get_all_request_ids.remote(), return_exceptions=True)
task.add_done_callback(partial(get_request_instance_done_callback, instance_id))
tasks.append(task)
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug("[_get_request_instance] instance_ids: {}".format(instance_ids))
logger.debug("[_get_request_instance] instance_requests: {}".format(instance_requests))
logger.info("instance_ids: {}".format(instance_ids))
logger.info("instance_requests: {}".format(instance_requests))
for (instance_id, requests) in zip(instance_ids, instance_requests):
for request_id in requests:
self.request_instance[request_id] = instance_id
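The core behavioural change in manager.py is the per-instance watch task: it waits a short interval, extends the wait when the instance actor is merely PENDING_CREATION, and only scales an instance down when its placement group exists without a live server or instance actor. Below is a compressed, self-contained sketch of that control flow; the state queries and `scale_down` here are stubs, not the Ray-backed helpers (`list_placement_groups` / `list_actors`) used in manager.py.

```python
import asyncio

# Constants copied from the values this PR sets in manager.py.
WATCH_DEPLOYMENT_INTERVAL = 10.0
WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE = 120.0

# Stubbed state queries; the real manager answers these through Ray's state API.
def instance_pending_creation(instance_id: str) -> bool:
    return False

def get_instance_deployment_states(instance_id: str):
    # Returns (pg_created, server_alive, instance_alive).
    return True, True, False

def scale_down(instance_id: str) -> None:
    print(f"scaling down {instance_id}")

async def watch_instance_deployment_states(instance_id: str) -> None:
    # _init_server_and_instance may still be in flight, so wait a short interval first.
    await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL)
    if instance_pending_creation(instance_id):
        # The instance actor exists but is still PENDING_CREATION: allow a much
        # longer grace period before judging the deployment broken.
        await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE)
    pg_created, server_alive, instance_alive = get_instance_deployment_states(instance_id)
    if pg_created and (not server_alive or not instance_alive):
        # A placement group without a healthy server/instance pair is leaked; clean it up.
        scale_down(instance_id)

asyncio.run(watch_instance_deployment_states("instance-0"))
```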
24 changes: 15 additions & 9 deletions tests/unit_test/global_scheduler/test_manager.py
@@ -316,7 +316,7 @@ def test_update_instance_info_loop_and_migrate(ray_env, manager):
else:
assert num_migrate_in == 0 and num_migrate_out == 0

def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
def test_init_server_and_get_instance_deployment_states_and_instance_and_clear_instance_ray_resources(ray_env):
manager, _, _, engine_args, _ = init_manager_with_launch_mode(LaunchMode.LOCAL)
instance_id = random_uuid()
pg = ray.get(manager._init_placement_group.remote(get_placement_group_name(instance_id),
@@ -333,8 +333,11 @@ def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
num_instances = ray.get(manager.scale_up.remote(instance_id, instance))
assert num_instances == 1

pg_created, server_alive, instance_alive = ray.get(manager._get_instance_deployment_states.remote(instance_id))
assert pg_created and server_alive and instance_alive

# test clear_instance_ray_resources
ray.get(manager._clear_instance_ray_resources.remote(instance_id))
ray.get(manager._clear_instance_ray_states.remote(instance_id))
# wait for remove and kill
time.sleep(1.0)
pg_exists = is_placement_group_exists(get_placement_group_name(instance_id))
@@ -344,25 +347,28 @@ def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
instance_exists = is_actor_exists(get_instance_name(instance_id))
assert not instance_exists

pg_created, server_alive, instance_alive = ray.get(manager._get_instance_deployment_states.remote(instance_id))
assert not pg_created and not server_alive and not instance_alive

@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
def test_auto_scale_up_loop_and_get_curr_deployment(ray_env, request_output_queue_type):
def test_auto_scale_up_loop_and_get_cluster_deployment(ray_env, request_output_queue_type):
manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type)
time.sleep(30.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
instance_ids = [actor_name_dict['name'].split("_")[-1] for actor_name_dict in actor_names_dict
if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)]
assert len(instance_ids) == 4
ray.get(manager._clear_instance_ray_resources.remote(instance_ids[0]))
ray.get(manager._clear_instance_ray_resources.remote(instance_ids[1]))
ray.get(manager._clear_instance_ray_states.remote(instance_ids[0]))
ray.get(manager._clear_instance_ray_states.remote(instance_ids[1]))
time.sleep(30.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
Expand All @@ -371,7 +377,7 @@ def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_ou
time.sleep(30.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
@@ -385,5 +391,5 @@ def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_ou
time.sleep(120.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4