Skip to content

Commit

Permalink
[Core] Support one-to-many and many-to-one migration
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Nov 5, 2024
1 parent 188b08e commit 41829fd
Show file tree
Hide file tree
Showing 25 changed files with 274 additions and 108 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,6 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*

# TODO file
TODO
1 change: 1 addition & 0 deletions configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ MANAGER:

MIGRATION_BACKEND: 'gloo'
MIGRATION_CACHE_BLOCKS: 512
MIGRATION_INTERNAL_CACHE_SIZE: 1

ENABLE_SCALING: False
5 changes: 5 additions & 0 deletions docs/Arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--migration-num-layers MIGRATION_NUM_LAYERS]
[--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
[--max-stages MAX_STAGES]
[--migration-internal-cache-size MIGRATION_INTERNAL_CACHE_SIZE]
[--log-request-timestamps]
```
Expand Down Expand Up @@ -165,6 +166,10 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Drop migration if the number of stages > max_stages.
- Default: 3

`--migration-internal-cache-size`
- Number of internal cache size in migration backend for sending and receiving
- Default: 2

`--log-request-timestamps`
- Enable logging request timestamps.

Expand Down
7 changes: 6 additions & 1 deletion llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class EngineManagerArgs:
migration_num_layers: int = None
last_stage_max_blocks: int = None
max_stages: int = None
migration_internal_cache_size: int = None

enable_pd_disagg: bool = None

Expand Down Expand Up @@ -176,7 +177,8 @@ def create_migration_config(self) -> MigrationConfig:
self.migration_num_layers,
self.last_stage_max_blocks,
self.max_stages,
self.migration_backend_init_timeout)
self.migration_backend_init_timeout,
self.migration_internal_cache_size)
return migration_config

@classmethod
Expand Down Expand Up @@ -302,6 +304,9 @@ def add_cli_args(
parser.add_argument('--last-stage-max-blocks',
type=int,
help='if the number pf remain blocks < last_stage_max_blocks, do last stage migration')
parser.add_argument('--migration-internal-cache-size',
type=int,
help='number of internal cache size in migration backend for sending and receiving')
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')
Expand Down
23 changes: 23 additions & 0 deletions llumnix/backends/migration_backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

from abc import ABC, abstractmethod
from typing import List
import queue

import torch

class MigrationBackendBase(ABC):
@abstractmethod
Expand All @@ -39,3 +41,24 @@ def do_send(self, dst_handle, blocks: List[int]):
@abstractmethod
def do_recv(self, src_handle, blocks: List[int]):
raise NotImplementedError

class CacheMigrationBackend(MigrationBackendBase):
def __init__(self, num_cache, cache_shape, cache_dtype, cache_device, pin_memory, *args, **kwargs):
super().__init__(*args, **kwargs)

self.num_cache = num_cache

self.dummy_cache = [
torch.empty(size=cache_shape, dtype=cache_dtype, device=cache_device, pin_memory=pin_memory)
for _ in range(self.num_cache)
]

self.avaiable_cache_queue = queue.Queue()
for i in range(self.num_cache):
self.avaiable_cache_queue.put_nowait(i)

def get_available_cache(self):
return self.avaiable_cache_queue.get()

def put_back_cache(self, cache_id):
self.avaiable_cache_queue.put_nowait(cache_id)
57 changes: 29 additions & 28 deletions llumnix/backends/vllm/migration_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ray.util.collective as col
from vllm.worker.cache_engine import CacheEngine
from llumnix.internal_config import MigrationConfig
from llumnix.backends.migration_backend_interface import MigrationBackendBase
from llumnix.backends.migration_backend_interface import MigrationBackendBase, CacheMigrationBackend
from llumnix.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -40,17 +40,16 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs):

NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16]

class RayRpcMigrationBackend(MigrationBackendBase):
class RayRpcMigrationBackend(CacheMigrationBackend):
def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \
scheduling_strategy, is_driver_worker, gpu_cache) -> None:
super().__init__()

self.migration_config = migration_config
self.cache_engine = cache_engine

self.worker_rank = worker_rank
self.worker_handle_list = worker_handle_list
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()
self.migration_stream = torch.cuda.Stream()

self.rpc_dtype = self.cache_engine.dtype
if self.cache_engine.dtype in NUMPY_SUPPORTED_DTYPES:
Expand All @@ -65,14 +64,10 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.num_migration_cache_blocks = self.migration_config.migration_cache_blocks
self.num_layers = self.cache_engine.num_layers
self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size
cache_shape = (self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size)

self.dummy_cache = torch.empty(
size=(self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size),
dtype=self.cache_engine.dtype,
device=self.cache_device,
pin_memory=True
)
self.migration_stream = torch.cuda.Stream()
super().__init__(migration_config.migration_internal_cache_size, cache_shape, self.cache_engine.dtype,
self.cache_device, pin_memory=True)

def init_backend(self, group_name, world_size, rank) -> bool:
logger.info("create rpc migration backend successfully.")
Expand Down Expand Up @@ -100,31 +95,37 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks)
if rpc_numpy_cache is not None:
self.do_recv(rpc_numpy_cache, recv_blocks)
rpc_numpy_cache = ray.get(ray_obj)
rpc_numpy_cache_ref = ray.get(ray_obj)
rpc_numpy_cache = ray.get(rpc_numpy_cache_ref)
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.do_recv(rpc_numpy_cache, recv_blocks)

def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
send_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
dummy_cache_idx = self.get_available_cache()
send_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}
with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[layer_idx], src_to_dst)
torch.cuda.Stream.synchronize(self.migration_stream)
return send_cache.to(self.rpc_dtype).numpy()
data = ray.put(send_cache.to(self.rpc_dtype).numpy())
self.put_back_cache(dummy_cache_idx)
return data

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
recv_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
dummy_cache_idx = self.get_available_cache()
recv_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
# use pin memory dummy_cache to speed up data transfer
recv_cache.copy_(torch.from_numpy(src_handle))

with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
self.cache_engine.attn_backend.swap_blocks(recv_cache[layer_idx], self.gpu_cache[layer_idx], src_to_dst)
torch.cuda.Stream.synchronize(self.migration_stream)
self.put_back_cache(dummy_cache_idx)

def try_import_gloo():
try:
Expand All @@ -139,11 +140,9 @@ def try_import_gloo():
except ImportError as e:
raise ImportError("Gloo is not installed. Please install it first.") from e

class RayColMigrationBackend(MigrationBackendBase):
class RayColMigrationBackend(CacheMigrationBackend):
def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, local_rank,
scheduling_strategy, is_driver_worker, gpu_cache) -> None:
super().__init__()

# pylint: disable=C0415
import cupy

Expand All @@ -162,6 +161,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.migration_stream = cupy.cuda.Stream()

self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size

Expand All @@ -172,14 +172,9 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.cache_device = torch.device(f"cuda:{self.local_rank}")

pin_memory = (self.backend == 'gloo')
self.dummy_cache = torch.empty(
size=(self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size),
dtype=self.cache_engine.dtype,
device=self.cache_device,
pin_memory=pin_memory
)

self.migration_stream = cupy.cuda.Stream()
cache_shape = (self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size)
super().__init__(migration_config.migration_internal_cache_size, cache_shape, self.cache_engine.dtype,
self.cache_device, pin_memory=pin_memory)

def init_backend(self, group_name, world_size, rank) -> bool:
@func_set_timeout(self.migration_config.migration_backend_init_timeout)
Expand Down Expand Up @@ -250,7 +245,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]

def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
dummy_cache_idx = self.get_available_cache()
send_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}

with self.migration_stream:
Expand All @@ -261,11 +257,13 @@ def do_send(self, dst_handle, blocks: List[int]):
# TODO(KuilongCui): check the error code if peer is dead
col.send(send_cache, dst_handle, self.group_name)
self.migration_stream.synchronize()
self.put_back_cache(dummy_cache_idx)

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
recv_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
dummy_cache_idx = self.get_available_cache()
recv_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)

with self.migration_stream:
for layer_idx in range(self.cache_engine.num_layers):
Expand All @@ -274,6 +272,7 @@ def do_recv(self, src_handle, blocks: List[int]):
col.recv(recv_cache, src_handle, self.group_name)
self.cache_engine.attn_backend.swap_blocks(recv_cache[cache_idx], self.gpu_cache[layer_idx], src_to_dst)
self.migration_stream.synchronize()
self.put_back_cache(dummy_cache_idx)

def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy,
is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase:
Expand All @@ -284,6 +283,8 @@ def get_migration_backend(migration_config: MigrationConfig, cache_engine: Cache

target_col = None
backend = migration_config.migration_backend
assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported backend: {} for VLLM".format(backend)

if backend in ['nccl', 'gloo']:
target_col = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy,
is_driver_worker, gpu_cache)
Expand Down
7 changes: 4 additions & 3 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_
cache_config: CacheConfig, parallel_config: ParallelConfig) -> int:
migrate_cache_blocks_size = migration_config.migration_cache_blocks
migrate_num_layers = migration_config.migration_num_layers
dummy_cache_size = migrate_num_layers * migrate_cache_blocks_size * CacheEngine.get_cache_block_size(
cache_config, model_config, parallel_config) // model_config.get_num_layers(parallel_config)
dummy_cache_size = migration_config.migration_internal_cache_size * migrate_num_layers * migrate_cache_blocks_size \
* CacheEngine.get_cache_block_size(cache_config, model_config, parallel_config) \
// model_config.get_num_layers(parallel_config)

# For nccl migration backend, reserve gpu memory for dummy cache in migration backend. For other backends,
# CPU memory is used for the dummy cache, which is almost unlimited, so no special action is needed.
Expand Down Expand Up @@ -118,7 +119,7 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block
total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size(
self.cache_config, self.model_config, self.parallel_config)
speed = total_kv_cache_size/_GB/(end_time - start_time)
logger.info("[migration_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s."
logger.info("[migrate cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s."
.format(len(src_blocks), convert_bytes(total_kv_cache_size), end_time-start_time, speed))

def do_recv(self, *args, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions llumnix/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
_C.MANAGER.MIGRATION_CACHE_BLOCKS = 512
# Number of kv-cache layers to transfer in each round during migration
_C.MANAGER.MIGRATION_NUM_LAYERS = 1
# Number of internal cache size in migration backend for sending and receiving
_C.MANAGER.MIGRATION_INTERNAL_CACHE_SIZE = 2

# -----------------------------------------------------------------------------
# SCALING CONFIGURATION
Expand Down
4 changes: 4 additions & 0 deletions llumnix/global_scheduler/dispatch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def remove_instance(self, instance_id: str) -> None:
if instance_id in self.available_dispatch_instance_set:
self.available_dispatch_instance_set.remove(instance_id)

if self.num_instances >= self.num_dispatch_instances:
free_instance_id = next(iter(self.instance_id_set - self.available_dispatch_instance_set))
self.available_dispatch_instance_set.add(free_instance_id)

def _sort_instance_infos(self,
descending: bool = True) -> None:
instance_infos: List[InstanceInfo] = list(self.instance_info.values())
Expand Down
4 changes: 2 additions & 2 deletions llumnix/global_scheduler/global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def dispatch(self) -> str:
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: PairMigrationConstraints, inflight_migrating: Dict[str, int]) -> List[Tuple[str, str]]:
self.migration_scheduler.update_instance_infos(self.instance_info)
migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type)
migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type, inflight_migrating)
return migrate_instance_pairs

def check_scale(self) -> Tuple[str, str]:
Expand Down
Loading

0 comments on commit 41829fd

Please sign in to comment.