Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Oct 9, 2024
1 parent 5258e78 commit 2f768d5
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
9 changes: 8 additions & 1 deletion llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,19 @@ def __init__(
self.state_lock = threading.Lock()
self.state = EngineState.INIT

self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()

def _start_engine_loop(self) -> None:
self._stop_event.clear()

with self.state_lock:
self.state = EngineState.RUNNING

while True:
while not self._stop_event.is_set():
try:
self.engine.step()
# pylint: disable=broad-except
Expand All @@ -263,6 +266,10 @@ def _start_engine_loop(self) -> None:
self.state = EngineState.CRASHED
break

with self.state_lock:
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED

def execute_worker_method(self, method, *args, **kwargs):
    """Forward *method* (with its positional and keyword args) to the vLLM
    driver worker and return whatever the worker's dispatcher produces."""
    driver_worker = self.engine.model_executor.driver_worker
    return driver_worker.execute_method(method, *args, **kwargs)

Expand Down
4 changes: 4 additions & 0 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def check_state(self):

with self.backend_engine.state_lock:
if self.backend_engine.state == EngineState.CRASHED:
self.backend_engine._stop_event.set()
if self.backend_engine._thread.is_alive():
self.backend_engine._thread.join()

self_actor = ray.get_actor(self.actor_name)
ray.kill(self_actor)

Expand Down
90 changes: 90 additions & 0 deletions tests/unit_test/llumlet/test_engine_step_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
import ray
import torch
import pytest

from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.engine.arg_utils import EngineArgs

from llumnix.backends.backend_interface import BackendType
from llumnix.llumlet.llumlet import Llumlet
from llumnix.internal_config import MigrationConfig
# pylint: disable=unused-import
from tests.conftest import setup_ray_env

@ray.remote(num_cpus=1, max_concurrency=4)
class MockLlumlet(Llumlet):
    """Llumlet whose engine step() can be toggled to raise, so tests can
    exercise the crash-detection / actor-teardown path."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Keep a reference to the real step() so it can be restored (and so the
        # faulty wrapper can still perform one genuine step before raising).
        self.origin_step = self.backend_engine.engine.step

    def set_error_step(self, broken: bool):
        """Restart the engine loop with a step() that raises when *broken* is True.

        Args:
            broken: if True, install a step() that performs one real step and
                then raises ValueError; if False, restore the original step().
        """
        # Stop the currently running engine loop before swapping step(),
        # otherwise the old loop could race with the replacement.
        self.backend_engine._stop_event.set()
        if self.backend_engine._thread.is_alive():
            self.backend_engine._thread.join()

        def raise_error_step():
            # BUGFIX: call the saved original step, NOT engine.step — after the
            # assignment below, engine.step IS this wrapper, so calling it
            # would recurse infinitely (RecursionError instead of the intended
            # ValueError).
            self.origin_step()
            raise ValueError("Mock engine step error")

        if broken:
            self.backend_engine.engine.step = raise_error_step
        else:
            self.backend_engine.engine.step = self.origin_step

        # Relaunch the engine loop thread with the (possibly faulty) step().
        self.backend_engine._thread = threading.Thread(
            target=self.backend_engine._start_engine_loop, args=(), daemon=True, name="engine_loop"
        )
        self.backend_engine._thread.start()

@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.")
def test_engine_step_exception(setup_ray_env):
    """A crashing engine loop must kill the llumlet actor and release its GPU memory."""
    vllm_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
    mig_config = MigrationConfig("LCFS", "rpc", 16, 1, 4, 5, 20)
    this_node = ray.get_runtime_context().get_node_id()
    strategy = NodeAffinitySchedulingStrategy(node_id=this_node, soft=False)

    # Baseline free GPU memory before the actor loads the model.
    free_before, _ = torch.cuda.mem_get_info()

    actor_name = "instance_0"
    llumlet = MockLlumlet.options(
        name=actor_name,
        namespace='llumnix',
        scheduling_strategy=strategy,
    ).remote(
        instance_id="0",
        backend_type=BackendType.VLLM,
        migration_config=mig_config,
        engine_args=vllm_args,
        node_id=this_node,
    )
    ray.get(llumlet.is_ready.remote())

    # The actor must be registered under its name and must have claimed GPU memory.
    registered = [actor["name"] for actor in ray.util.list_named_actors(True)]
    assert actor_name in registered

    free_with_actor, _ = torch.cuda.mem_get_info()
    assert free_with_actor < free_before

    # Break engine.step(); the watchdog should detect the crash and kill the actor.
    ray.get(llumlet.set_error_step.remote(True))
    time.sleep(3)

    registered = [actor["name"] for actor in ray.util.list_named_actors(True)]
    assert actor_name not in registered

    # Everything the dead actor held on the GPU must have been returned.
    free_after, _ = torch.cuda.mem_get_info()
    assert free_before == free_after

0 comments on commit 2f768d5

Please sign in to comment.