Revert "[Core] Support one-to-many and many-to-one migration (#63)"
This reverts commit 844c836.
s5u13b committed Dec 9, 2024
1 parent 17e4c23 · commit c17cec4
Showing 16 changed files with 184 additions and 329 deletions.
3 changes: 1 addition & 2 deletions configs/base.yml
@@ -19,7 +19,6 @@ MANAGER:
REQUEST_MIGRATION_POLICY: 'SR'

MIGRATION_BACKEND: 'gloo'
MIGRATION_BUFFER_BLOCKS: 512
MIGRATION_INTERNAL_BUFFER_NUM: 2
MIGRATION_CACHE_BLOCKS: 512

ENABLE_SCALING: False
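
For context, a minimal sketch of how the reverted MANAGER keys could be inspected with plain PyYAML; llumnix's own config loader is not part of this diff, so the loading code below is only an illustration of which keys remain after the revert.

```python
# Illustrative only: PyYAML stands in for llumnix's actual config machinery.
import yaml

with open("configs/base.yml") as f:
    manager_cfg = yaml.safe_load(f)["MANAGER"]

# Removed again by this revert:
assert "MIGRATION_BUFFER_BLOCKS" not in manager_cfg
assert "MIGRATION_INTERNAL_BUFFER_NUM" not in manager_cfg
# Restored by this revert:
print(manager_cfg["MIGRATION_CACHE_BLOCKS"])  # 512
```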
13 changes: 4 additions & 9 deletions docs/Arguments.md
@@ -32,15 +32,14 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--profiling-result-file-path PROFILING_RESULT_FILE_PATH]
[--gpu-type GPU_TYPE]
[--polling-interval POLLING_INTERVAL]
[--migration-backend {gloo,rpc}]
[--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS]
[--migration-backend {gloo,nccl,rpc}]
[--migration-cache-blocks MIGRATION_CACHE_BLOCKS]
[--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT]
[--migration-num-layers MIGRATION_NUM_LAYERS]
[--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
[--max-stages MAX_STAGES]
[--enable-pd-disagg]
[--num-dispatch-instances NUM_DISPATCH_INSTANCES]
[--migration-internal-buffer-num MIGRATION_INTERNAL_BUFFER_NUM]
[--log-request-timestamps]
```
@@ -148,8 +147,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Possible choices: gloo, rpc
- Default: "rpc"

`--migration-buffer-blocks`
- Number of cache blocks in each migration buffer.
`--migration-cache-blocks`
- Number of cache blocks in migration.
- Default: 512

`--migration-backend-init-timeout`
@@ -168,10 +167,6 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Drop migration if the number of stages > max_stages.
- Default: 3

`--migration-internal-buffer-num`
- Number of the buffer in migration backend for sending and receiving
- Default: 2

`--log-request-timestamps`
- Enable logging request timestamps.

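
A short, hedged illustration of the documented defaults; the real parser lives in llumnix/arg_utils.py, and the argparse stub below only mirrors the flags and defaults listed above.

```python
# Stand-in parser mirroring the documented flags; not llumnix's own code.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--migration-backend', choices=['gloo', 'rpc'], default='rpc')
parser.add_argument('--migration-cache-blocks', type=int, default=512)
parser.add_argument('--max-stages', type=int, default=3)

args = parser.parse_args(['--migration-backend', 'gloo'])
print(args.migration_backend, args.migration_cache_blocks, args.max_stages)
# gloo 512 3
```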
18 changes: 5 additions & 13 deletions llumnix/arg_utils.py
@@ -134,11 +134,10 @@ class EngineManagerArgs:

migration_backend_init_timeout: float = None
migration_backend: str = None
migration_buffer_blocks: int = None
migration_cache_blocks: int = None
migration_num_layers: int = None
last_stage_max_blocks: int = None
max_stages: int = None
migration_internal_buffer_num: int = None

enable_pd_disagg: bool = None

@@ -173,12 +172,11 @@ def create_global_scheduler_configs(
def create_migration_config(self) -> MigrationConfig:
migration_config = MigrationConfig(self.request_migration_policy,
self.migration_backend,
self.migration_buffer_blocks,
self.migration_cache_blocks,
self.migration_num_layers,
self.last_stage_max_blocks,
self.max_stages,
self.migration_backend_init_timeout,
self.migration_internal_buffer_num)
self.migration_backend_init_timeout)
return migration_config

@classmethod
@@ -197,9 +195,6 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser):
if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest):
assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}."

assert args.migration_backend != 'nccl', 'NCCL has been temporarily deprecated due to its incompatibility with \
concurrent migrations in Llumnix.'

assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
("When using gloo as migration backend, "
@@ -314,18 +309,15 @@ def add_cli_args(
parser.add_argument('--migration-backend-init-timeout',
type=float,
help='timeout(s) for initializing migration backend')
parser.add_argument('--migration-buffer-blocks',
parser.add_argument('--migration-cache-blocks',
type=int,
help='number of cache blocks in each migration buffer')
help='number of cache blocks in migration')
parser.add_argument('--migration-num-layers',
type=int,
help='number of kv-cache layers to transfer in each round during migration')
parser.add_argument('--last-stage-max-blocks',
type=int,
help='if the number pf remain blocks < last_stage_max_blocks, do last stage migration')
parser.add_argument('--migration-internal-buffer-num',
type=int,
help='number of the buffer in migration backend for sending and receiving')
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')
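
To make the reverted positional order concrete, here is a hedged construction of MigrationConfig matching the create_migration_config hunk above; the import path appears in this commit, but the values are illustrative placeholders rather than defaults confirmed by the diff.

```python
# Sketch only: argument order follows the reverted create_migration_config above.
from llumnix.internal_config import MigrationConfig

migration_config = MigrationConfig(
    'SR',    # request_migration_policy (value from configs/base.yml)
    'gloo',  # migration_backend
    512,     # migration_cache_blocks
    1,       # migration_num_layers (illustrative)
    16,      # last_stage_max_blocks (illustrative)
    3,       # max_stages
    10.0,    # migration_backend_init_timeout in seconds (illustrative)
)
```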
23 changes: 0 additions & 23 deletions llumnix/backends/migration_backend_interface.py
@@ -13,9 +13,7 @@

from abc import ABC, abstractmethod
from typing import List
import queue

import torch

class MigrationBackendBase(ABC):
@abstractmethod
@@ -41,24 +39,3 @@ def do_send(self, dst_handle, blocks: List[int]):
@abstractmethod
def do_recv(self, src_handle, blocks: List[int]):
raise NotImplementedError

class BufferMigrationBackend(MigrationBackendBase):
def __init__(self, num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory, *args, **kwargs):
super().__init__(*args, **kwargs)

self.num_buffer = num_buffer

self.dummy_buffer = [
torch.empty(size=buffer_shape, dtype=buffer_dtype, device=buffer_device, pin_memory=pin_memory)
for _ in range(self.num_buffer)
]

self.avaiable_buffer_queue = queue.Queue()
for i in range(self.num_buffer):
self.avaiable_buffer_queue.put_nowait(i)

def get_available_cache(self):
return self.avaiable_buffer_queue.get()

def put_back_cache(self, buffer_id):
self.avaiable_buffer_queue.put_nowait(buffer_id)
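
The class deleted above implemented a small pool of pre-allocated staging buffers handed out through a queue, a common pattern for overlapping concurrent sends and receives. A condensed, hedged sketch of that pattern (names are renamed for illustration; this is not the exact removed code):

```python
# Buffer-pool pattern removed by this revert, condensed for illustration.
import queue
from typing import Tuple

import torch

class StagingBufferPool:
    def __init__(self, num_buffer: int, shape: Tuple[int, ...], dtype, device, pin_memory: bool):
        # Pre-allocate a fixed set of staging tensors (pinned only when on CPU).
        self.buffers = [
            torch.empty(size=shape, dtype=dtype, device=device, pin_memory=pin_memory)
            for _ in range(num_buffer)
        ]
        self.free = queue.Queue()
        for i in range(num_buffer):
            self.free.put_nowait(i)

    def acquire(self) -> int:
        # Blocks until a buffer index is free, bounding concurrent transfers.
        return self.free.get()

    def release(self, idx: int) -> None:
        self.free.put_nowait(idx)
```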
89 changes: 42 additions & 47 deletions llumnix/backends/vllm/migration_backend.py
@@ -15,15 +15,11 @@
import torch
from func_timeout import func_set_timeout, FunctionTimedOut

import cupy
from cupy.cuda import nccl
import ray
import ray.util.collective as col
from ray.util.collective.collective_group import nccl_util

from vllm.worker.cache_engine import CacheEngine
from llumnix.internal_config import MigrationConfig
from llumnix.backends.migration_backend_interface import MigrationBackendBase, BufferMigrationBackend
from llumnix.backends.migration_backend_interface import MigrationBackendBase
from llumnix.logger import init_logger

logger = init_logger(__name__)
@@ -44,16 +40,17 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs):

NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16]

class RayRpcMigrationBackend(BufferMigrationBackend):
class RayRpcMigrationBackend(MigrationBackendBase):
def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \
scheduling_strategy, is_driver_worker, gpu_cache) -> None:
super().__init__()

self.migration_config = migration_config
self.cache_engine = cache_engine

self.worker_rank = worker_rank
self.worker_handle_list = worker_handle_list
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()
self.migration_stream = torch.cuda.Stream()

self.rpc_dtype = self.cache_engine.dtype
if self.cache_engine.dtype in NUMPY_SUPPORTED_DTYPES:
@@ -65,13 +62,17 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.cache_device = "cpu"
self.num_migration_buffer_blocks = self.migration_config.migration_buffer_blocks
self.num_migration_cache_blocks = self.migration_config.migration_cache_blocks
self.num_layers = self.cache_engine.num_layers
self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size
buffer_shape = (self.num_migration_buffer_blocks, self.num_layers, 2, self.migration_cache_size)

super().__init__(migration_config.migration_internal_buffer_num, buffer_shape, self.cache_engine.dtype,
self.cache_device, pin_memory=True)
self.dummy_cache = torch.empty(
size=(self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size),
dtype=self.cache_engine.dtype,
device=self.cache_device,
pin_memory=True
)
self.migration_stream = torch.cuda.Stream()

def init_backend(self, group_name, world_size, rank) -> bool:
logger.info("create rpc migration backend successfully.")
@@ -93,46 +94,37 @@ def warmup(self) -> bool:
def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
rpc_numpy_cache = None
for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks):
offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx)
for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks):
offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx)
send_blocks = src_blocks[start_idx:start_idx+offset]
ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks)
if rpc_numpy_cache is not None:
self.do_recv(rpc_numpy_cache, recv_blocks)
rpc_numpy_cache_ref = ray.get(ray_obj)
rpc_numpy_cache = ray.get(rpc_numpy_cache_ref)
rpc_numpy_cache = ray.get(ray_obj)
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.do_recv(rpc_numpy_cache, recv_blocks)

def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
dummy_cache_idx = self.get_available_cache()
send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
send_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}
with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[layer_idx], src_to_dst)
torch.cuda.Stream.synchronize(self.migration_stream)
# Here, we use ray.put to store data and finally return the object reference so that we can release the internal buffer.
# This might seem like an anti-pattern, but it's okay since the kv-cache transferred is in the MB range and won't utilize
# Ray's optimization for returning small objects (<100KB).
data = ray.put(send_cache.to(self.rpc_dtype).numpy())
self.put_back_cache(dummy_cache_idx)
return data
return send_cache.to(self.rpc_dtype).numpy()

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
dummy_cache_idx = self.get_available_cache()
recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
recv_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
# use pin memory dummy_cache to speed up data transfer
recv_cache.copy_(torch.from_numpy(src_handle))

with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
self.cache_engine.attn_backend.swap_blocks(recv_cache[layer_idx], self.gpu_cache[layer_idx], src_to_dst)
torch.cuda.Stream.synchronize(self.migration_stream)
self.put_back_cache(dummy_cache_idx)

def try_import_gloo():
try:
@@ -147,14 +139,19 @@ def try_import_gloo():
except ImportError as e:
raise ImportError("Gloo is not installed. Please install it first.") from e

class RayColMigrationBackend(BufferMigrationBackend):
class RayColMigrationBackend(MigrationBackendBase):
def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, local_rank,
scheduling_strategy, is_driver_worker, gpu_cache) -> None:
super().__init__()

# pylint: disable=C0415
import cupy

self.migration_config = migration_config
self.cache_engine = cache_engine
self.backend = migration_config.migration_backend
self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine.num_layers)
self.num_migration_buffer_blocks = migration_config.migration_buffer_blocks
self.num_migration_cache_blocks = migration_config.migration_cache_blocks

self.backend = migration_config.migration_backend
self.global_world_size = -1
@@ -165,21 +162,24 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.migration_stream = cupy.cuda.Stream()

self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size

if self.backend == 'gloo':
try_import_gloo()
self.cache_device = "cpu"
else:
nccl_util.TORCH_NCCL_DTYPE_MAP[torch.bfloat16] = nccl.NCCL_FLOAT16
self.cache_device = torch.device(f"cuda:{self.local_rank}")

pin_memory = (self.backend == 'gloo')
buffer_shape = (self.num_migration_buffer_blocks, self.migration_num_layers, 2, self.migration_cache_size)
super().__init__(migration_config.migration_internal_buffer_num, buffer_shape, self.cache_engine.dtype,
self.cache_device, pin_memory=pin_memory)
self.dummy_cache = torch.empty(
size=(self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size),
dtype=self.cache_engine.dtype,
device=self.cache_device,
pin_memory=pin_memory
)

self.migration_stream = cupy.cuda.Stream()

def init_backend(self, group_name, world_size, rank) -> bool:
@func_set_timeout(self.migration_config.migration_backend_init_timeout)
@@ -224,7 +224,7 @@ def destory_backend(self) -> None:
def warmup(self) -> bool:
if self.global_world_size > 1:
try:
col.allreduce(self.dummy_buffer[0][0], self.group_name)
col.allreduce(self.dummy_cache[0], self.group_name)
# pylint: disable=W0703
except Exception as e:
logger.info("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {}), err: {}."
@@ -241,17 +241,16 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
tot_blocks = len(src_blocks)
src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_global_rank"))

for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks):
offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx)
for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks):
offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx)
send_blocks = src_blocks[start_idx:start_idx+offset]
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.global_rank, send_blocks)
self.do_recv(src_rank, recv_blocks)

def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
dummy_cache_idx = self.get_available_cache()
send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}

with self.migration_stream:
@@ -262,13 +261,11 @@ def do_send(self, dst_handle, blocks: List[int]):
# TODO(KuilongCui): check the error code if peer is dead
col.send(send_cache, dst_handle, self.group_name)
self.migration_stream.synchronize()
self.put_back_cache(dummy_cache_idx)

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
dummy_cache_idx = self.get_available_cache()
recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
recv_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)

with self.migration_stream:
for layer_idx in range(self.cache_engine.num_layers):
@@ -277,18 +274,16 @@ def do_recv(self, src_handle, blocks: List[int]):
col.recv(recv_cache, src_handle, self.group_name)
self.cache_engine.attn_backend.swap_blocks(recv_cache[cache_idx], self.gpu_cache[layer_idx], src_to_dst)
self.migration_stream.synchronize()
self.put_back_cache(dummy_cache_idx)

def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy,
is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase:
if cache_engine.num_gpu_blocks < migration_config.migration_buffer_blocks:
logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
.format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks))
migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks
if cache_engine.num_gpu_blocks < migration_config.migration_cache_blocks:
logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
.format(migration_config.migration_cache_blocks, cache_engine.num_gpu_blocks))
migration_config.migration_cache_blocks = cache_engine.num_gpu_blocks

target_migration_backend = None
backend = migration_config.migration_backend
assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported migration backend: {} for llumnix".format(backend)

if backend in ['nccl', 'gloo']:
target_migration_backend = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy,
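
Both backends in this file move KV blocks in bounded rounds. A self-contained sketch of that chunking follows; the helper name below is invented for illustration, while the real loops live inside each backend's migrate_cache.

```python
# Illustrative chunking helper: mirrors the migrate_cache loops above, which
# copy at most migration_cache_blocks source/destination blocks per round.
from typing import Iterator, List, Tuple

def chunked_block_pairs(src_blocks: List[int], dst_blocks: List[int],
                        chunk: int) -> Iterator[Tuple[List[int], List[int]]]:
    tot_blocks = len(src_blocks)
    for start_idx in range(0, tot_blocks, chunk):
        offset = min(chunk, tot_blocks - start_idx)
        yield (src_blocks[start_idx:start_idx + offset],
               dst_blocks[start_idx:start_idx + offset])

# Example: 5 blocks with migration_cache_blocks == 2 -> 3 rounds.
for send_blocks, recv_blocks in chunked_block_pairs([10, 11, 12, 13, 14],
                                                    [3, 4, 5, 6, 7], 2):
    print(send_blocks, recv_blocks)
# [10, 11] [3, 4]
# [12, 13] [5, 6]
# [14] [7]
```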
(Diffs for the remaining 11 changed files are not shown.)