From 35e425946a4dcb7ea3cdbcf6f1c0dfa39fe6f3d0 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Wed, 23 Oct 2024 09:16:14 +0000
Subject: [PATCH] Fix status bugs

---
 llumnix/backends/vllm/llm_engine.py           | 13 ++++++++----
 llumnix/backends/vllm/scheduler.py            | 20 ++++++++++---------
 llumnix/backends/vllm/sequence.py             |  8 +++++---
 llumnix/llumlet/llumlet.py                    | 18 ++++++++++-------
 llumnix/llumlet/request.py                    | 11 +++++++++-
 tests/e2e_test/test_migration.py              |  4 ++--
 .../unit_test/backends/vllm/test_migration.py |  2 ++
 .../unit_test/backends/vllm/test_scheduler.py |  6 +++---
 8 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index 163a65e5..f8c1a620 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -353,11 +353,16 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
         pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id)
         self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
         backend_request.reset_migration_args_dst()
-        if backend_request.status == RequestStatus.RUNNING:
-            self.add_running_request(backend_request)
-        else: # RequestStatus.WAITING
-            backend_request.waiting_migrating = True
+        assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \
+            "The status of request migrated to dst instance should be \
+            RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING"
+        if backend_request.status == RequestStatus.WAITING_MIGRATING:
+            self.engine.scheduler.set_status(backend_request, status_to=SequenceStatus.WAITING)
             self.add_waiting_request(backend_request)
+        elif backend_request.status == RequestStatus.RUNNING_MIGRATING:
+            backend_request.reset_status()
+            self.engine.scheduler.set_status(backend_request, status_to=SequenceStatus.RUNNING)
+            self.add_running_request(backend_request)
 
     async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
         await dst_ray_actor.execute_engine_method.remote("_run_workers",
diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py
index 3b7f7ed1..42d83760 100644
--- a/llumnix/backends/vllm/scheduler.py
+++ b/llumnix/backends/vllm/scheduler.py
@@ -103,6 +103,7 @@ def remove_running_request(self, request_id: str) -> bool:
         for seq_group in self.running:
             if seq_group.request_id == request_id:
                 self.running.remove(seq_group)
+                seq_group.set_status(RequestStatus.RUNNING_MIGRATING)
                 return True
         return False
 
@@ -110,6 +111,7 @@ def remove_waiting_request(self, request_id: str) -> bool:
         for seq_group in self.waiting:
             if seq_group.request_id == request_id:
                 self.waiting.remove(seq_group)
+                seq_group.set_status(RequestStatus.WAITING_MIGRATING)
                 return True
         return False
 
@@ -131,7 +133,7 @@ def pre_alloc(self, block_num: int) -> List[int]:
         # Only migrate waiting request when the waiting request is the earliest arrival one
         # among the requests of dst instance's waiting queue.
-        if request_status == RequestStatus.WAITING:
+        if request_status == RequestStatus.WAITING_MIGRATING:
             if (self.waiting and request_arrival_time > self.waiting[0].arrival_time) \
                 or block_num * self.cache_config.block_size > self.prompt_limit:
                 return []
@@ -152,23 +154,23 @@ def add_waiting_request(self, backend_request: LlumnixRequest) -> None:
         self.waiting = fcfs_policy.sort_by_priority(time.time(), self.waiting)
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
-        if seq_group.waiting_migrating:
+        if seq_group.status == RequestStatus.WAITING_MIGRATING:
             return AllocStatus.OK
         return super().can_allocate(seq_group)
 
     def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
         # Change seq status to running, but request status is still waiting_migrating.
-        if seq_group.waiting_migrating:
+        if seq_group.status == RequestStatus.WAITING_MIGRATING:
             # For the waiting request migrated in, blocks have already been allocated when pre alloc.
-            self._set_status(seq_group, status_to=SequenceStatus.RUNNING)
-            seq_group.waiting_migrating = False
+            self.set_status(seq_group, status_to=SequenceStatus.RUNNING)
+            seq_group.reset_status()
         else:
             super()._allocate_and_set_running(seq_group)
 
-    def _set_status(self,
-                    seq_group: SequenceGroup,
-                    status_to: SequenceStatus,
-                    status_from: SequenceStatus = None):
+    def set_status(self,
+                   seq_group: SequenceGroup,
+                   status_to: SequenceStatus,
+                   status_from: SequenceStatus = None):
         for seq in seq_group.get_seqs(status=status_from):
             seq.status = status_to
 
diff --git a/llumnix/backends/vllm/sequence.py b/llumnix/backends/vllm/sequence.py
index 59a6ef08..5964f96d 100644
--- a/llumnix/backends/vllm/sequence.py
+++ b/llumnix/backends/vllm/sequence.py
@@ -52,13 +52,15 @@ def arrival_time(self) -> float:
 
     @property
     def status(self) -> RequestStatus:
+        if self._status:
+            return self._status
         status = self.get_seqs()[0].status
-        assert status in [SequenceStatus.RUNNING, SequenceStatus.WAITING], \
-            "Only RUNNING, WAITING are expected status for LlumnixRequest"
         if status == SequenceStatus.RUNNING:
             request_status = RequestStatus.RUNNING
-        else:
+        elif status == SequenceStatus.WAITING:
             request_status = RequestStatus.WAITING
+        else:
+            request_status = RequestStatus.FINISHED
         return request_status
 
     @property
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index 36f53944..c19e780d 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -148,20 +148,21 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds
         dst_instance_id = dst_instance_name[len("instance_"):]
         logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id))
         migrated_request = []
-        assert migrate_out_request.status in [RequestStatus.WAITING, RequestStatus.RUNNING], "Only migrate out waiting and running request"
         if migrate_out_request.status == RequestStatus.RUNNING:
             status = await self.migration_coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request)
-        else:
+        elif migrate_out_request.status == RequestStatus.WAITING:
             status = await self.migration_coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request)
+        else:
+            return migrated_request
         if status == MigrationStatus.FINISHED_DONE:
             await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request)
-            if migrate_out_request.status == RequestStatus.RUNNING:
-                self.backend_engine.free_src_request(migrate_out_request)
+            self.backend_engine.free_src_request(migrate_out_request)
             self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
             migrated_request.append(migrate_out_request.request_id)
         else: # FINISHED_SRC_ABORTED or FINISHED_DST_ABORTED
             migrate_out_request.reset_migration_args_src()
-            # If dst aborts itself, dst proactively frees the pre alloc cache during pre alloc.
+            migrate_out_request.reset_status()
+            # If dst aborts itself, dst proactively frees the pre allocated cache in migrate_in_pre_alloc.
             if status == MigrationStatus.FINISHED_SRC_ABORTED:
                 await migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id)
         t1 = time.time()
@@ -218,9 +219,12 @@ def clear_migration_states(self, is_migrate_in: bool) -> None:
         migrating_out_requests_last_stage = self.backend_engine.pop_migrating_out_requests_last_stage()
         for backend_request in migrating_out_requests_last_stage:
             logger.info("clear_migration_states: add request {} back to engine".format(backend_request.request_id))
-            if backend_request.status == RequestStatus.RUNNING:
+            assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \
+                "The status of request in migrating_out_requests_last_stage should be \
+                RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING"
+            if backend_request.status == RequestStatus.RUNNING_MIGRATING:
                 self.backend_engine.add_running_request(backend_request)
-            else: # RequestStatus.WAITING
+            elif backend_request.status == RequestStatus.WAITING_MIGRATING:
                 self.backend_engine.add_waiting_request(backend_request)
 
     def execute_migration_method(self, method, *args, **kwargs):
diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py
index 0b864045..d92e6564 100644
--- a/llumnix/llumlet/request.py
+++ b/llumnix/llumlet/request.py
@@ -23,6 +23,9 @@ class RequestInferenceType(str, Enum):
 class RequestStatus(str, Enum):
     RUNNING = "running"
     WAITING = "waiting"
+    FINISHED = "finished"
+    RUNNING_MIGRATING = "running_migrating"
+    WAITING_MIGRATING = "waiting_migrating"
 
 class LlumnixRequest:
     def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int) -> None:
@@ -37,7 +40,7 @@ def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int
         self.stage_timestamps = []
         self.stage_num_blocks_list = []
         self.try_schedule_times = 0
-        self.waiting_migrating = False
+        self._status = None
 
         # end-of-migration, for multiple requests migration
         self.eom = False
@@ -56,6 +59,12 @@ def reset_migration_args_src(self):
         self.stage_timestamps = []
         self.stage_num_blocks_list = []
 
+    def reset_status(self):
+        self._status = None
+
+    def set_status(self, status: RequestStatus):
+        self._status = status
+
     @property
     def inference_type(self) -> RequestInferenceType:
         raise NotImplementedError
diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py
index e35ff8fd..a8d72c6d 100644
--- a/tests/e2e_test/test_migration.py
+++ b/tests/e2e_test/test_migration.py
@@ -66,7 +66,7 @@ def parse_manager_log_file(log_file):
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
 @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
 @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
-@pytest.mark.parametrize("migrated_request_status", ['running', 'waiting'])
+@pytest.mark.parametrize("migrated_request_status", ['waiting', 'running'])
 async def test_migration_benchmark(model, migration_backend, migrated_request_status):
     if migrated_request_status == 'waiting' and migration_backend != 'rpc':
         pytest.skip("When the migrated request status is waiting, only test the rpc migration backend.")
@@ -104,8 +104,8 @@ async def run_bench_command(command):
 
     parse_manager_log_file("manager_instance.csv")
 
-    average_speed = parse_instance_log_file(instance_output_logs)
     if migrated_request_status == 'running':
+        average_speed = parse_instance_log_file(instance_output_logs)
         sorted_keys = sorted(average_speed.keys(), key=lambda x: float(x.split()[0]))
         data = [
             ['migration_size'] + sorted_keys,
diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py
index 3dffe8a3..865f28a7 100644
--- a/tests/unit_test/backends/vllm/test_migration.py
+++ b/tests/unit_test/backends/vllm/test_migration.py
@@ -294,10 +294,12 @@ def test_clear_migration_states():
     llumlet.clear_migration_states(is_migrate_in=True)
     assert len(llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, num_gpu_blocks)) == num_gpu_blocks
     _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.RUNNING)
+    seq_group.set_status(RequestStatus.RUNNING_MIGRATING)
     llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group)
     llumlet.clear_migration_states(is_migrate_in=False)
     assert len(llumlet.backend_engine.get_running_queue()) == 1
     _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.WAITING)
+    seq_group.set_status(RequestStatus.WAITING_MIGRATING)
     llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group)
     llumlet.clear_migration_states(is_migrate_in=False)
     assert len(llumlet.backend_engine.get_waiting_queue()) == 1
diff --git a/tests/unit_test/backends/vllm/test_scheduler.py b/tests/unit_test/backends/vllm/test_scheduler.py
index 10874edd..c8a03981 100644
--- a/tests/unit_test/backends/vllm/test_scheduler.py
+++ b/tests/unit_test/backends/vllm/test_scheduler.py
@@ -203,12 +203,12 @@ def test_schedule_running():
     before_arrival = time.time()
     _, seq_group = create_dummy_prompt("1", prompt_length=1, block_size=2, expected_steps=math.inf)
     after_arrival = time.time()
-    blocks = scheduler.pre_alloc("2", RequestStatus.WAITING, after_arrival, 2)
+    blocks = scheduler.pre_alloc("2", RequestStatus.WAITING_MIGRATING, after_arrival, 2)
     assert len(blocks) == 2
     scheduler.add_waiting_request(seq_group)
-    blocks = scheduler.pre_alloc("3", RequestStatus.WAITING, after_arrival, 2)
+    blocks = scheduler.pre_alloc("3", RequestStatus.WAITING_MIGRATING, after_arrival, 2)
     assert len(blocks) == 0
-    blocks = scheduler.pre_alloc("4", RequestStatus.WAITING, before_arrival, 2)
+    blocks = scheduler.pre_alloc("4", RequestStatus.WAITING_MIGRATING, before_arrival, 2)
     assert len(blocks) == 2
 
 def test_try_schedule_times():
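
For readers tracing the fix: the core of this patch is replacing the old boolean
waiting_migrating flag on LlumnixRequest with an explicit _status override that
the scheduler sets when it dequeues a request for migration and that is cleared
when the dst instance commits the request (or the src aborts and re-enqueues it).
Below is a minimal, self-contained sketch of that override mechanism; the
standalone Request class is a simplified stand-in for SequenceGroupLlumnix, not
the actual llumnix class, and the base-status field stands in for the status
derived from get_seqs()[0].

from enum import Enum

class RequestStatus(str, Enum):
    RUNNING = "running"
    WAITING = "waiting"
    FINISHED = "finished"
    RUNNING_MIGRATING = "running_migrating"
    WAITING_MIGRATING = "waiting_migrating"

class Request:
    def __init__(self, base_status: RequestStatus):
        # Override set while the request is mid-migration; None means
        # "derive the status from the underlying sequences".
        self._status = None
        self._base_status = base_status  # stand-in for get_seqs()[0].status

    @property
    def status(self) -> RequestStatus:
        # Mirrors the patched status property: the explicit override,
        # when set, wins over the derived status.
        return self._status if self._status is not None else self._base_status

    def set_status(self, status: RequestStatus):
        self._status = status

    def reset_status(self):
        self._status = None

# Lifecycle as exercised by the patch: remove_running_request marks the
# request on dequeue, commit_dst_request (or an abort path) clears the mark.
req = Request(RequestStatus.RUNNING)
req.set_status(RequestStatus.RUNNING_MIGRATING)   # src scheduler dequeues it
assert req.status is RequestStatus.RUNNING_MIGRATING
req.reset_status()                                # dst commits the request
assert req.status is RequestStatus.RUNNING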