
Commit ea8764e: sync

KuilongCui committed Feb 7, 2025
1 parent 03e5173 commit ea8764e
Showing 22 changed files with 71 additions and 106 deletions.
45 changes: 0 additions & 45 deletions llumnix/arg_utils.py
@@ -16,7 +16,6 @@
import argparse
import dataclasses
from dataclasses import dataclass
import argparse
from typing import List, Tuple, Union

from llumnix.internal_config import GlobalSchedulerConfig, MigrationConfig
@@ -454,47 +453,3 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
type=int,
help='if the number of remaining blocks < migration_last_stage_max_blocks, do last stage migration')
return parser

@dataclass
class InstanceArgs:
instance_type: str = None

def __post_init__(self):
# Check if all fields default to None
for field_info in dataclasses.fields(self):
if field_info.default is not None:
raise ValueError(f"The default value of '{field_info.name}' should be None")

for attr in dataclasses.fields(self):
if getattr(self, attr.name) is None:
setattr(self, attr.name, getattr(_C.INSTANCE, attr.name.upper()))

@classmethod
def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'InstanceArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
# The default values of attributes are defined in default.py.
instance_args = cls(**{attr: getattr(cfg.INSTANCE, attr.upper()) for attr in attrs})
return instance_args

@classmethod
def check_args(cls, args: 'InstanceArgs', manager_args: EngineManagerArgs, parser: argparse.ArgumentParser):
# pylint: disable=protected-access
for action in parser._optionals._actions:
if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest):
assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}."

# instance_type check
if manager_args.enable_pd_disagg:
assert args.instance_type in ['prefill', 'decode'], \
"instance_type should be prefill or decode if enable_pd_disagg is set."

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--instance-type',
type=str,
choices=['prefill', 'decode', 'no_constraints'],
help='instance type for the engine')

return parser
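
The 45 lines deleted above follow a defaults-from-config dataclass pattern: every field defaults to None, and `__post_init__` backfills unset fields from a global config object. A minimal runnable sketch of that pattern, with `_C` as a stand-in for the llumnix defaults defined in default.py (not the real llumnix API):

```python
# Minimal sketch of the defaults-from-config pattern; _C stands in for the
# llumnix config object and is NOT the real llumnix API.
import dataclasses
from dataclasses import dataclass
from types import SimpleNamespace

_C = SimpleNamespace(INSTANCE=SimpleNamespace(INSTANCE_TYPE='no_constraints'))

@dataclass
class InstanceArgsSketch:
    instance_type: str = None

    def __post_init__(self):
        # Every field must default to None so config defaults can win.
        for field_info in dataclasses.fields(self):
            if field_info.default is not None:
                raise ValueError(f"The default value of '{field_info.name}' should be None")
        # Backfill unset fields from the uppercase key in the config.
        for attr in dataclasses.fields(self):
            if getattr(self, attr.name) is None:
                setattr(self, attr.name, getattr(_C.INSTANCE, attr.name.upper()))

assert InstanceArgsSketch().instance_type == 'no_constraints'
assert InstanceArgsSketch(instance_type='prefill').instance_type == 'prefill'
```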
19 changes: 10 additions & 9 deletions llumnix/backends/bladellm/llm_engine.py
@@ -16,15 +16,14 @@
from functools import partial
import json
import traceback
from typing import Dict, List, Optional, Tuple, Union, Iterable, Any
from typing import List, Optional, Tuple, Union, Iterable, Any
from collections import defaultdict
import threading
import asyncio
import queue

import ray
import grpc
import ray.actor
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@@ -119,24 +118,24 @@ def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamRes
# pylint: disable=unused-argument
async def send(self, req_id, msg, reset=False):
if req_id not in self.request_server_map:
logger.warning("req_id {} not in request_server_map, maybe already be migrated.", req_id)
logger.warning("req_id {} not in request_server_map, maybe already be migrated.".format(req_id))
return

self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_server_map[req_id]))
if msg.is_finished:
logger.debug(f"trans_wrapper finish_request {req_id}")
logger.debug("trans_wrapper finish_request {}".format(req_id))
self.request_server_map.pop(req_id, None)

async def recv(self):
return None

def drop_request(self, request_id: int) -> None:
self.request_server_map.pop(request_id, None)
logger.debug("trans_wrapper drop_request {}", request_id)
logger.debug("trans_wrapper drop_request {}".format(request_id))

def add_request(self, request_id: int, server_info: ServerInfo) -> None:
self.request_server_map[request_id] = server_info
logger.debug("trans_wrapper add_request {} {}", request_id, server_info)
logger.debug("trans_wrapper add_request {} {}".format(request_id, server_info))

def clear(self):
self.request_server_map = {}
@@ -152,6 +151,8 @@ def __init__(self,
placement_group: PlacementGroup,
request_output_queue_type: QueueType,
migration_config: MigrationConfig,
src_worker_ip_address: List[str],
request_barriers: queue.Queue,
) -> None:
self.instance_id = instance_id

@@ -208,7 +209,7 @@ def _update_request_inference_type(self, resp_list: List[WorkerStepResponse]):
RequestInferenceType.DECODE if num_out_token > 0 else RequestInferenceType.PREFILL

async def update_callback(self, resp_list, step_requests):
logger.debug("update_callback {} {}", resp_list, step_requests)
logger.debug("update_callback {} {}".format(resp_list, step_requests))
await super().update_callback(resp_list, step_requests)
self._update_request_inference_type(resp_list)
self.scheduler.llumnix_metrics.engine_step_metrics(self.scheduler)
@@ -257,7 +258,7 @@ async def add_request(self, server_info: ServerInfo, server_request: ServerReque
await self._client._add_request(server_request)

async def drop_request(self, req_id: int):
logger.debug("engine {} drop request {}", self.instance_id, req_id)
logger.debug("engine {} drop request {}".format(self.instance_id, req_id))
await self._client.drop_request(req_id)

async def run_workers(self, worker_method, *args, **kwargs):
@@ -506,4 +507,4 @@ def commit_dst_request(self, backend_request: LlumnixRequest) -> None:
self.engine.scheduler.llumnix_metrics.scheduler_step_metrics(self.engine.scheduler)
if self.engine._migration_semaphore.locked():
self.engine._migration_semaphore.release()
logger.debug("commit dst request {}", backend_request.request_id)
logger.debug("commit dst request {}".format(backend_request.request_id))
8 changes: 6 additions & 2 deletions llumnix/backends/bladellm/migration_backend.py
@@ -155,7 +155,7 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
for layer_idx in range(self.kv_cache_arena.num_layers):
self.kv_cache_arena.events[layer_idx].wait()

# pylint: disable=unused-argument
# pylint: disable=unused-argument,arguments-differ
def do_send(self, request, context):
blocks = list(request.src_blocks)

@@ -196,6 +196,7 @@ def do_send(self, request, context):

return responce

# pylint: disable=unused-argument,arguments-differ
def do_recv(self, src_handle, blocks: List[int]):
# use pin memory dummy_cache to speed up data transfer
num_blocks = len(blocks)
@@ -291,7 +292,7 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
self.check_recv_done(src_handle.instance_id, src_handle.worker_id, migration_kv_uuid,
src_blocks, dst_blocks)

# pylint: disable=unused-argument
# pylint: disable=unused-argument,arguments-differ
def do_send(self, request, context):
dst_instance_id, dst_worker_id = request.dst_instance_id, request.dst_worker_id
self.client_kv.add_worker(dst_instance_id, dst_worker_id, 0, len(self.kv_cache), self.tranfer_type)
@@ -306,6 +307,9 @@ def do_send(self, request, context):

return migration_worker_pb2.SendKvCacheResponse()

def do_recv(self, request, context):
pass

def check_recv_done(self, src_instance_id, src_worker_id, kv_request_id: str,
src_blocks: List[int], dst_blocks: List[int]):
timeout_threshold_ms = 30
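
`do_recv` in this file stages data through a page-locked "dummy cache" before copying block-by-block into the GPU KV cache, per the "use pin memory dummy_cache to speed up data transfer" comment. A rough sketch of that staging pattern in PyTorch; the shapes and names are illustrative, not the llumnix layout, and it needs a CUDA device:

```python
# Rough sketch of staging KV-cache blocks through pinned host memory.
# Shapes/names are illustrative, not the llumnix layout; needs a CUDA device.
import torch

num_layers, max_blocks, block_size, head_dim = 4, 64, 16, 128
gpu_cache = torch.zeros(num_layers, max_blocks, block_size, head_dim, device="cuda")
# Page-locked host memory makes host-to-device copies faster and async-capable.
dummy_cache = torch.empty(num_layers, 8, block_size, head_dim, pin_memory=True)

def recv_blocks(payload: torch.Tensor, dst_blocks: list) -> None:
    # Stage the received tensor in pinned memory, then scatter per block.
    dummy_cache[:, :payload.shape[1]].copy_(payload)
    for idx, block_num in enumerate(dst_blocks):
        gpu_cache[:, block_num].copy_(dummy_cache[:, idx], non_blocking=True)
    torch.cuda.synchronize()

recv_blocks(torch.randn(num_layers, 2, block_size, head_dim), [3, 7])
```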
2 changes: 1 addition & 1 deletion llumnix/backends/bladellm/sequence.py
@@ -71,7 +71,7 @@ def finished(self) -> bool:
return self._status == RequestStatus.FINISHED

@property
def arrival_time(self) -> float:
def request_arrival_time(self) -> float:
return self.receive_time

@property
1 change: 1 addition & 0 deletions llumnix/backends/bladellm/worker_client.py
@@ -22,6 +22,7 @@ def barrier(self, callback: Optional[Callable] = None):
meta_requests = [WorkerMetaRequest(method="barrier").SerializeToString()]

if self.dispatch_mode == ResponseMode.ONE:
# pylint: disable=expression-not-assigned
[self.rpc_call(meta_requests[0], i) for i in range(len(self.reader))]
elif self.dispatch_mode == ResponseMode.ALL:
for i, request in enumerate(meta_requests):
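
The new pylint disable covers a list comprehension evaluated purely for its side effects (one RPC per reader). A small sketch of the idiom with stand-in names, plus the plain loop that avoids the warning:

```python
# The idiom behind the new disable: a comprehension run only for side effects.
# rpc_call/readers are stand-ins for the worker client's members.
readers = ["worker-0", "worker-1", "worker-2"]

def rpc_call(payload: bytes, reader_idx: int) -> None:
    print(f"dispatch {payload!r} to {readers[reader_idx]}")

payload = b"barrier"
# pylint: disable=expression-not-assigned
[rpc_call(payload, i) for i in range(len(readers))]  # value discarded on purpose

# The plain loop says the same thing without tripping the checker:
for i in range(len(readers)):
    rpc_call(payload, i)
```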
4 changes: 2 additions & 2 deletions llumnix/backends/migration_backend_interface.py
@@ -32,9 +32,9 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
raise NotImplementedError

@abstractmethod
def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int):
def do_send(self, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int):
def do_recv(self, *args, **kwargs):
raise NotImplementedError
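
Loosening the abstract `do_send`/`do_recv` signatures to `*args, **kwargs` lets each backend declare its own parameters, which is why the concrete overrides elsewhere in this commit gain `# pylint: disable=arguments-differ`. A minimal sketch of the pattern with illustrative names:

```python
# Sketch of the loosened interface: the base class accepts anything and each
# backend narrows the signature, hence the arguments-differ disables.
from abc import ABC, abstractmethod
from typing import List

class MigrationBackendBase(ABC):
    @abstractmethod
    def do_send(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def do_recv(self, *args, **kwargs):
        raise NotImplementedError

class DemoBackend(MigrationBackendBase):
    # pylint: disable=arguments-differ
    def do_send(self, blocks: List[int], virtual_engine: int = 0):
        return f"send {len(blocks)} blocks on engine {virtual_engine}"

    # pylint: disable=arguments-differ
    def do_recv(self, src_handle, blocks: List[int], virtual_engine: int = 0):
        return f"recv {len(blocks)} blocks from {src_handle}"

print(DemoBackend().do_send([1, 2, 3]))
```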
21 changes: 10 additions & 11 deletions llumnix/backends/vllm/llm_engine.py
@@ -39,8 +39,8 @@
from llumnix.internal_config import MigrationConfig
from llumnix.queue.utils import QueueType
from llumnix.backends.utils import AsyncPutQueueActor
from llumnix.utils import get_instance_name
from llumnix.llumlet.request import LlumnixRequest
from llumnix.utils import get_instance_name

logger = init_logger(__name__)

@@ -198,6 +198,7 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request
self.instance_info = instance_info

# pylint: disable=invalid-overridden-method
async def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs):
super().add_request(request_id, *args, **kwargs)
seq_group = self.scheduler[0].waiting[-1]
@@ -297,14 +298,9 @@ async def is_ready(self) -> bool:
def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)

async def add_request(self,
request_id: str,
server_info: ServerInfo,
expected_steps: int,
*args,
**kwargs) -> None:
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
self.engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
async def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None:
await self.engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)

def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
@@ -346,8 +342,11 @@ def get_running_queue(self) -> List[SequenceGroupLlumnix]:
def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]:
return self.engine.scheduler[0].get_waiting_queue()

async def get_request_incremental_blocks(self, *args, **kwargs) -> Tuple[List[int], List[int]]:
return self.engine.scheduler[0].get_request_incremental_blocks(*args, **kwargs)
async def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> Tuple[List[int], List[int]]:
incremental_blocks, incremental_token_ids = \
self.engine.scheduler[0].get_request_incremental_blocks(backend_request, pre_stage_num_blocks)
is_last_stage = (len(incremental_blocks) <= self.migration_config.migration_last_stage_max_blocks) or backend_request.blocking_migration
return incremental_blocks, incremental_token_ids, is_last_stage

def remove_running_request(self, *args, **kwargs) -> None:
return self.engine.scheduler[0].remove_running_request(*args, **kwargs)
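
`get_request_incremental_blocks` now computes `is_last_stage` inside the backend: migration enters its final stage once the unsent tail fits under `migration_last_stage_max_blocks`, or when the request only supports blocking migration. A sketch of that predicate with stand-in types:

```python
# Sketch of the new last-stage decision; FakeRequest stands in for
# LlumnixRequest and last_stage_max_blocks for migration_last_stage_max_blocks.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeRequest:
    blocking_migration: bool

def is_last_stage(incremental_blocks: List[int],
                  last_stage_max_blocks: int,
                  request: FakeRequest) -> bool:
    # Final stage once the unsent tail is small enough to move while the
    # request is paused, or when the request cannot migrate incrementally.
    return len(incremental_blocks) <= last_stage_max_blocks or request.blocking_migration

assert is_last_stage([1, 2, 3], 16, FakeRequest(blocking_migration=False))
assert not is_last_stage(list(range(100)), 16, FakeRequest(blocking_migration=False))
```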
6 changes: 5 additions & 1 deletion llumnix/backends/vllm/migration_backend.py
@@ -103,7 +103,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.do_recv(rpc_numpy_cache, recv_blocks)

def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int=0):
# pylint: disable=arguments-differ
def do_send(self, blocks: List[int], virtuel_engine: int=0):
num_blocks = len(blocks)
send_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
# src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}
@@ -120,6 +121,7 @@ def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int=0):
torch.cuda.Stream.synchronize(self.migration_stream)
return send_cache.to(self.rpc_dtype).numpy()

# pylint: disable=arguments-differ
def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int=0):
num_blocks = len(blocks)
# src_to_dst = dict(enumerate(blocks))
@@ -261,6 +263,7 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.global_rank, send_blocks)
self.do_recv(src_rank, recv_blocks)

# pylint: disable=unused-argument,arguments-differ
def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int=0):
num_blocks = len(blocks)
send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
@@ -280,6 +283,7 @@ def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int=0):
col.send(send_cache, dst_handle, self.group_name)
self.migration_stream.synchronize()

# pylint: disable=unused-argument,arguments-differ
def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int=0):
num_blocks = len(blocks)
src_to_dst: List[Tuple[int, int]] = []
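
The collective backend pairs `col.send` on the source worker with a matching `col.recv` on the destination inside a named group. A bare-bones sketch, assuming `ray.util.collective` with a group that was already initialized elsewhere (e.g. via `col.init_collective_group`); the group name is illustrative:

```python
# Bare-bones send/recv pairing; assumes ray.util.collective and a group
# named "migration" initialized elsewhere via col.init_collective_group.
import torch
import ray.util.collective as col

GROUP_NAME = "migration"

def send_side(send_cache: torch.Tensor, dst_rank: int) -> None:
    # Blocks until the matching recv posts on dst_rank; shapes must agree.
    col.send(send_cache, dst_rank, GROUP_NAME)

def recv_side(recv_cache: torch.Tensor, src_rank: int) -> None:
    # Fills recv_cache in place with the peer's tensor.
    col.recv(recv_cache, src_rank, GROUP_NAME)
```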
2 changes: 1 addition & 1 deletion llumnix/entrypoints/setup.py
@@ -51,7 +51,7 @@ def launch_ray_cluster(port: int) -> subprocess.CompletedProcess:
sys.exit(1)
ray_start_command = None
if 'HEAD_NODE' in os.environ:
ray_start_command = f"ray start --head --node-ip-address={node_ip_address} --port={port} --log-dir=/mnt/xinyi/custom_logs"
ray_start_command = f"ray start --head --node-ip-address={node_ip_address} --port={port}"
try:
result = subprocess.run(['ray', 'start', '--head', f'--port={port}'], check=True, text=True, capture_output=True)
except subprocess.CalledProcessError as e:
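
The fix drops a machine-specific `--log-dir=/mnt/xinyi/custom_logs` from the head-node command. A sketch of the resulting launch path, mirroring the file's error handling; the function name and addresses are illustrative:

```python
# Sketch of the head-node launch after dropping the hardcoded --log-dir;
# launch_head and the addresses below are illustrative, not llumnix API.
import os
import subprocess
import sys

def launch_head(node_ip_address: str, port: int) -> subprocess.CompletedProcess:
    cmd = f"ray start --head --node-ip-address={node_ip_address} --port={port}"
    try:
        return subprocess.run(cmd.split(), check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print(e.stderr, file=sys.stderr)
        sys.exit(1)

if 'HEAD_NODE' in os.environ:
    launch_head("127.0.0.1", 6379)
```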
1 change: 0 additions & 1 deletion llumnix/llumlet/llumlet.py
@@ -39,7 +39,6 @@

class Llumlet:
def __init__(self,
instance_args: InstanceArgs,
instance_id: str,
instance_args: InstanceArgs,
placement_group: PlacementGroup,
6 changes: 2 additions & 4 deletions llumnix/llumlet/migration_coordinator.py
@@ -15,9 +15,6 @@
import traceback
import enum
from typing import List
import traceback

import ray

from llumnix.logging.logger import init_logger
from llumnix.llumlet.request import LlumnixRequest, RequestStatus
@@ -119,7 +116,8 @@ async def _migrate_out_onestage(self,
return MigrationStatus.ABORTED_SRC

pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list)
incremental_blocks, is_last_stage = await self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks)
incremental_blocks, incremental_token_ids, is_last_stage = \
await self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks)

if migrate_out_request.should_abort_migration():
return MigrationStatus.ABORTED_SRC
18 changes: 9 additions & 9 deletions llumnix/llumlet/request.py
@@ -122,17 +122,17 @@ def status(self) -> RequestStatus:
def prefill_num_blocks(self) -> int:
raise NotImplementedError

@property
def n_blocks(self) -> int:
raise NotImplementedError
# @property
# def n_blocks(self) -> int:
# raise NotImplementedError

@property
def token_ids(self) -> int:
raise NotImplementedError
# @property
# def token_ids(self) -> int:
# raise NotImplementedError

@property
def block_size(self) -> int:
raise NotImplementedError
# @property
# def block_size(self) -> int:
# raise NotImplementedError

# Whether the migration of request is completed within one stage. For requests that have already reached
# the expected steps, blocking_migration is True.
3 changes: 0 additions & 3 deletions llumnix/manager.py
@@ -271,13 +271,10 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -

try:
migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type)
if len(migrate_instance_pairs) > 0:
logger.info(f"migrate_instance_pairs: {migrate_instance_pairs}")
migration_tasks = []
for _, migrate_instance_pair in enumerate(migrate_instance_pairs):
migrate_out_instance_id, migrate_in_instance_id = migrate_instance_pair
if self.instance_migrating[migrate_out_instance_id] or self.instance_migrating[migrate_in_instance_id]:
logger.info(f"migrate_instance_pairs: {migrate_instance_pairs} is migrating")
continue
self.instance_migrating[migrate_out_instance_id] = True
self.instance_migrating[migrate_in_instance_id] = True
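
The deleted lines here are log statements only; the guard that skips a pair when either instance is already mid-migration stays. A small sketch of that guard with stand-in data (the real `instance_migrating` dict lives on the Manager):

```python
# Stand-in sketch of the migrating-pair guard kept by this commit.
instance_migrating = {"i0": False, "i1": True, "i2": False}

def runnable_pairs(pairs):
    for out_id, in_id in pairs:
        # Skip the pair if either endpoint is already part of a migration.
        if instance_migrating[out_id] or instance_migrating[in_id]:
            continue
        instance_migrating[out_id] = True
        instance_migrating[in_id] = True
        yield out_id, in_id

print(list(runnable_pairs([("i0", "i1"), ("i0", "i2")])))  # [('i0', 'i2')]
```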
@@ -22,7 +22,7 @@

# pylint: disable=unused-import
from tests.conftest import ray_env
from .utils import (generate_launch_command, generate_serve_command, wait_for_llumnix_service_ready,
from tests.e2e_test.utils import (generate_launch_command, generate_serve_command, wait_for_llumnix_service_ready,
shutdown_llumnix_service)


@@ -22,7 +22,7 @@

# pylint: disable=unused-import
from tests.conftest import ray_env
from .utils import (generate_launch_command, generate_bench_command, to_markdown_table,
from tests.e2e_test.utils import (generate_launch_command, generate_bench_command, to_markdown_table,
wait_for_llumnix_service_ready, shutdown_llumnix_service)

size_pattern = re.compile(r'total_kv_cache_size:\s*([\d.]+)\s*(B|KB|MB|GB|KB|TB)')
5 changes: 3 additions & 2 deletions tests/unit_test/backends/vllm/test_llm_engine.py
@@ -69,7 +69,8 @@ def test_llm_engine_from_engine_args_sim(ray_env):
placement_group=placement_group)
assert llm_engine.executor_class == SimGPUExecutor

def test_llm_engine_add_requset(ray_env):
@pytest.mark.asyncio
async def test_llm_engine_add_requset(ray_env):
engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
latency_data = LatencyMemData({},{},{})
placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=1, num_gpus=0, detached=True)
Expand All @@ -81,7 +82,7 @@ def test_llm_engine_add_requset(ray_env):
migration_config=None)
sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)
server_info = ServerInfo(None, None, None, None, None)
llm_engine.add_request("0", server_info, math.inf, "prompt", sampling_params)
await llm_engine.add_request("0", server_info, math.inf, "prompt", sampling_params)
assert len(llm_engine.scheduler[0].waiting) == 1
assert llm_engine.scheduler[0].waiting[-1].request_id == "0"
assert llm_engine.scheduler[0].waiting[-1].expected_steps == math.inf
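
Because `add_request` became a coroutine, the test turns into an `async def` marked with `@pytest.mark.asyncio`. A self-contained sketch of that conversion, assuming pytest-asyncio is installed and using a hypothetical `fake_add_request` in place of the engine:

```python
# Sketch of the sync-to-async test conversion; assumes pytest-asyncio, and
# fake_add_request is a hypothetical stand-in for the engine's add_request.
import asyncio
import pytest

async def fake_add_request(request_id: str) -> str:
    await asyncio.sleep(0)  # yield to the event loop like a real awaitable
    return request_id

@pytest.mark.asyncio
async def test_add_request():
    # Awaiting requires the coroutine test function that the marker enables.
    assert await fake_add_request("0") == "0"
```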