[Refactor] Asynchronous llumlet #56

Merged · 6 commits · Oct 16, 2024
4 changes: 2 additions & 2 deletions llumnix/backends/backend_interface.py
@@ -68,7 +68,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
raise NotImplementedError

@abstractmethod
def _start_engine_step_loop(self) -> None:
async def _start_engine_step_loop(self) -> None:
"""Start step loop of backend engine.
"""
raise NotImplementedError
@@ -244,7 +244,7 @@ def free_src_request(self, backend_request: LlumnixRequest) -> None:
raise NotImplementedError

@abstractmethod
def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]):
async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]):
"""
Sends cache blocks from the source instance to the destination instance.

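The interface change above turns `_start_engine_step_loop` and `send_blocks` into coroutines, so every concrete backend now has to declare them with `async def` and drive them from an event loop rather than a thread. A minimal sketch of what satisfying the updated interface could look like; `DummyBackend` and its method bodies are hypothetical stand-ins, not code from this PR:

```python
import asyncio
from abc import ABC, abstractmethod
from typing import List


class BackendInterface(ABC):
    @abstractmethod
    async def _start_engine_step_loop(self) -> None:
        """Start step loop of backend engine."""
        raise NotImplementedError

    @abstractmethod
    async def send_blocks(self, dst_ray_actor, src_blocks: List[int], dst_blocks: List[int]) -> None:
        raise NotImplementedError


class DummyBackend(BackendInterface):
    async def _start_engine_step_loop(self) -> None:
        while True:
            # one engine step per iteration; awaiting yields control to other coroutines
            await asyncio.sleep(0.01)

    async def send_blocks(self, dst_ray_actor, src_blocks: List[int], dst_blocks: List[int]) -> None:
        # placeholder for the awaited block transfer; it must not block the event loop
        await asyncio.sleep(0)
```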
26 changes: 14 additions & 12 deletions llumnix/backends/vllm/executor.py
@@ -12,16 +12,18 @@
# limitations under the License.

import time
import asyncio

from collections import defaultdict
from typing import List, Optional, Tuple
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
# pylint: disable=unused-import
from ray.util.placement_group import PlacementGroup

from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayWorkerWrapper, get_distributed_init_method,\
get_ip, get_vllm_instance_id, get_open_port
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper,\
get_distributed_init_method, get_ip, get_vllm_instance_id, get_open_port

from vllm import envs
from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest
@@ -34,7 +36,7 @@

logger = init_logger(__name__)

class LlumnixRayGPUExecutor(RayGPUExecutor):
class LlumnixRayGPUExecutor(RayGPUExecutorAsync):
node_id: str = None
migration_config: MigrationConfig = None

@@ -157,17 +159,17 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
cache_config=self.cache_config,
parallel_config=self.parallel_config)

def execute_model(self, *args, **kwargs):
async def execute_model_async(self, *args, **kwargs):
t0 = time.time()
outputs = super().execute_model(*args, **kwargs)
outputs = await super().execute_model_async(*args, **kwargs)
t1 = time.time()
self.last_inference_latency = (t1 - t0) * 1000
return outputs

class SimGPUExecutor(GPUExecutor):
class SimGPUExecutor(RayGPUExecutor):
latency_mem: LatencyMemData = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
RayGPUExecutor.__init__(self, *args, **kwargs)
self.last_inference_latency = 0
self.migration_bandwidth = self.latency_mem.migration_bandwidth
# TODO(ZeldaHuang): add swap bandwidth
@@ -191,7 +193,7 @@ def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)

def execute_model(
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
prefill_seq_len = 0
@@ -213,7 +215,7 @@ def execute_model(
decode_meta_data = (decode_bs, decode_seq_len)
latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \
else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params)
time.sleep(latency/1000)
await asyncio.sleep(latency/1000)
sampler_outputs = []
for meta_data in execute_model_req.seq_group_metadata_list:
samples = []
@@ -225,6 +227,6 @@
sampler_outputs.append(output)
return [SamplerOutput(outputs=sampler_outputs)]

def send_blocks(self, blocks_len) -> None:
async def send_blocks(self, blocks_len) -> None:
migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth
time.sleep(migration_latency)
await asyncio.sleep(migration_latency)
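Worth noting for the simulated executor: a blocking `time.sleep` would stall the whole actor once the engine runs as coroutines, so the modeled step and migration latencies are now awaited with `asyncio.sleep`. A rough, self-contained illustration of the difference (the function names and the 50 ms figure are illustrative only, not from the repo):

```python
import asyncio
import time


def simulated_step_blocking(latency_ms: float) -> None:
    # freezes the entire event loop for the simulated duration
    time.sleep(latency_ms / 1000)


async def simulated_step_async(latency_ms: float) -> None:
    # suspends only this coroutine; migration and other tasks keep running
    await asyncio.sleep(latency_ms / 1000)


async def main() -> None:
    # two simulated 50 ms steps overlap instead of serializing
    await asyncio.gather(simulated_step_async(50), simulated_step_async(50))


asyncio.run(main())
```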
107 changes: 49 additions & 58 deletions llumnix/backends/vllm/llm_engine.py
@@ -22,7 +22,7 @@
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy

from vllm.engine.llm_engine import LLMEngine
from vllm.engine.async_llm_engine import _AsyncLLMEngine
from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.outputs import RequestOutput
from vllm.sequence import SequenceGroup, SequenceStatus, SamplerOutput, SequenceGroupMetadata
@@ -82,7 +82,7 @@ async def put_nowait_to_servers(self,
logger.error("exception traceback: {}".format(traceback.format_exc()))


class LLMEngineLlumnix(LLMEngine):
class LLMEngineLlumnix(_AsyncLLMEngine):
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
@@ -171,38 +171,37 @@ def _process_model_outputs(
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[List[RequestOutput], List[ServerInfo]]:
# ensure scheduled_seq_groups matching output
with self.scheduler.scheduler_lock:
server_infos = []
if output:
new_output = []
new_scheduled_seq_groups = []
new_seq_group_metadata_list = []
for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
seq_group = scheduled_seq_group.seq_group
if seq_group.get_seqs(SequenceStatus.RUNNING):
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
server_infos.append(seq_group.server_info)
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
for ignored_seq_group in ignored_seq_groups:
server_infos.append(ignored_seq_group.server_info)
for server_info in server_infos:
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
for request_output, server_info in zip(request_outputs, server_infos):
if hasattr(server_info, 'request_timestamps'):
request_output.request_timestamps = server_info.request_timestamps
request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time()
# TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_infos

def step(self) -> None:
server_infos = []
if output:
new_output = []
new_scheduled_seq_groups = []
new_seq_group_metadata_list = []
for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
seq_group = scheduled_seq_group.seq_group
if seq_group.get_seqs(SequenceStatus.RUNNING):
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
server_infos.append(seq_group.server_info)
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
for ignored_seq_group in ignored_seq_groups:
server_infos.append(ignored_seq_group.server_info)
for server_info in server_infos:
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
for request_output, server_info in zip(request_outputs, server_infos):
if hasattr(server_info, 'request_timestamps'):
request_output.request_timestamps = server_info.request_timestamps
request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time()
# TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_infos

async def step_async(self) -> None:
step_begin_time = time.time()
request_outputs, server_infos = super().step()
request_outputs, server_infos = await super().step_async()
for request_output in request_outputs:
if hasattr(request_output, 'request_timestamps'):
request_output.request_timestamps.engine_step_timestamp_begin = step_begin_time
@@ -251,7 +250,6 @@ def add_request(self, request_id: str, server_info: ServerInfo, expected_steps:
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
seq_group.sampling_params, seq_group.metrics.arrival_time, seq_group.lora_request,
seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _start_put_queue_loop(self):
while True:
@@ -301,45 +299,38 @@ def __init__(
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)

self.state_lock = threading.Lock()
self.state = EngineState.INIT
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
)
self.engine_step_loop_thread.start()
self._stop_event = asyncio.Event()
asyncio.create_task(self._start_engine_step_loop())

def _start_engine_step_loop(self) -> None:
async def _start_engine_step_loop(self) -> None:
self._stop_event.clear()

with self.state_lock:
previous_state = self.state
self.state = EngineState.RUNNING
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
previous_state = self.state
self.state = EngineState.RUNNING
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))

while not self._stop_event.is_set():
try:
request_outputs, _ = self.engine.step()
request_outputs, _ = await self.engine.step_async()
if len(request_outputs) == 0:
time.sleep(0.01)
await asyncio.sleep(0.01)
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))
self._run_workers("shutdown")

with self.state_lock:
previous_state = self.state
self.state = EngineState.CRASHED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
previous_state = self.state
self.state = EngineState.CRASHED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
break

with self.state_lock:
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
@@ -362,12 +353,12 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
backend_request.reset_migration_args()
self.add_running_request(backend_request)

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers",
async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
await dst_ray_actor.execute_engine_method.remote("_run_workers",
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list))
src_worker_handle_list=self.worker_handle_list)

def _run_workers(self, *args, **kwargs):
# pylint: disable=protected-access
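The core of the refactor is visible in this file: the engine step loop no longer runs on a daemon thread guarded by `threading.Event` and `state_lock`; it is scheduled with `asyncio.create_task` on the actor's event loop, and `send_blocks` awaits the Ray `ObjectRef` instead of blocking in `ray.get`. A condensed sketch of the loop pattern, assuming a stand-in `AsyncEngine`/`step_async` (the real loop additionally tracks `EngineState` and shuts workers down on failure):

```python
import asyncio


class AsyncEngine:
    """Illustrative stand-in for the asynchronous step loop, not the real class."""

    def __init__(self) -> None:
        self._stop_event = asyncio.Event()

    def start(self) -> None:
        # must be called from code already running on the event loop
        self._task = asyncio.create_task(self._engine_step_loop())

    async def step_async(self) -> list:
        # placeholder for the real engine step; returns finished request outputs
        await asyncio.sleep(0.001)
        return []

    async def _engine_step_loop(self) -> None:
        while not self._stop_event.is_set():
            request_outputs = await self.step_async()
            if not request_outputs:
                # back off briefly when the engine is idle
                await asyncio.sleep(0.01)


async def main() -> None:
    engine = AsyncEngine()
    engine.start()
    await asyncio.sleep(0.05)  # let the loop run a few iterations
    engine._stop_event.set()
    await engine._task


asyncio.run(main())
```

On the migration path, awaiting `dst_ray_actor.execute_engine_method.remote(...)` yields to other coroutines while the remote call completes, whereas the previous `ray.get` would have blocked the event loop for the whole transfer.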
23 changes: 0 additions & 23 deletions llumnix/backends/vllm/scheduler.py
@@ -13,7 +13,6 @@

from asyncio.log import logger
import time
import threading
from typing import Dict, List, Optional, Tuple
from collections import deque

@@ -23,7 +22,6 @@
from llumnix.instance_info import InstanceInfo
from llumnix.logger import init_logger
from llumnix.llumlet.request import RequestInferenceType
from llumnix.backends.vllm.utils import scheduler_lock
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix

logger = init_logger(__name__)
@@ -56,7 +54,6 @@ def __init__(self, *args, **kwargs) -> None:
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage: List[SequenceGroupLlumnix] = []

def add_update_instance_info_callback(self, update_instance_info_callback):
@@ -79,25 +76,21 @@ def _get_num_killed_requests(self) -> int:
cnt += 1
return cnt

@scheduler_lock
def get_running_queue(self):
return self.running

@scheduler_lock
def get_all_request_ids(self) -> List[str]:
request_ids : List[str] = []
for state_queue in [self.waiting, self.running, self.swapped]:
for seq_group in state_queue:
request_ids.append(seq_group.request_id)
return request_ids

@scheduler_lock
def get_request_incremental_blocks(self, backend_request: SequenceGroupLlumnix, pre_stage_num_blocks: int) -> List[int]:
seq = backend_request.get_seqs()[0]
blocks = self.block_manager.get_block_table(seq)
return blocks[pre_stage_num_blocks:]

@scheduler_lock
def remove_running_request(self, request_id: str) -> None:
for seq_group in self.running:
if seq_group.request_id == request_id:
@@ -117,7 +110,6 @@ def pop_migrating_out_requests_last_stage(self) -> List[SequenceGroupLlumnix]:
self.migrating_out_request_last_stage.clear()
return migrating_out_request_last_stage

@scheduler_lock
def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
blocks = self.block_manager.get_free_blocks(block_num)
pre_blocks = self.pre_alloc_cache_dict.get(request_id, [])
@@ -126,17 +118,14 @@ def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
blocks = [block.block_number for block in blocks]
return blocks

@scheduler_lock
def add_running_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
self.running.append(backend_request)

@scheduler_lock
def is_request_running(self, backend_request: SequenceGroupLlumnix) -> bool:
return backend_request in self.running

@scheduler_lock
def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
if request_id:
blocks = self.pre_alloc_cache_dict.pop(request_id, [])
@@ -150,7 +139,6 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
# pylint: disable=protected-access
self.block_manager._free_block_table(blocks)

@scheduler_lock
def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
logger.info("free seq {}".format(seq.seq_id))
@@ -201,7 +189,6 @@ def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -
instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()]
return instance_info

@scheduler_lock
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_group_metadata_list, scheduler_outputs = super().schedule()
self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \
@@ -220,13 +207,3 @@ def _schedule_running(self, running_queue: deque, *args, **kwargs):
for seq_group in remove_running:
remaining_running.extend([seq_group])
return remaining_running, running_scheduled

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is mannually released in the end of LLMEngineLlumnix.add_request function.
# pylint: disable=R1732
self.scheduler_lock.acquire()
return super().add_seq_group(*args, **kwargs)

@scheduler_lock
def abort_seq_group(self, *args, **kwargs):
return super().abort_seq_group(*args, **kwargs)
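
The lock removal in the scheduler follows directly from the changes above: once every entry point (`add_request`, `step_async`, the migration methods) runs as a coroutine on one asyncio event loop, scheduler state is never touched from two threads at once, so the `scheduler_lock` decorator and the manual acquire/release pair around `add_seq_group`/`add_request` become unnecessary. The property relied on is that code between two `await` points runs atomically with respect to other coroutines; a minimal illustration (names are illustrative, not from the repo):

```python
import asyncio

waiting: list = []  # shared scheduler state, only ever touched from the event loop


async def add_request(request_id: str) -> None:
    # no lock needed: nothing else can run until this coroutine hits an await
    waiting.append(request_id)


async def schedule() -> list:
    scheduled = list(waiting)
    waiting.clear()
    return scheduled


async def main() -> None:
    await asyncio.gather(*(add_request(f"req-{i}") for i in range(3)))
    print(await schedule())  # ['req-0', 'req-1', 'req-2']


asyncio.run(main())
```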