Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Nov 6, 2024
1 parent 102fb86 commit 923fe8d
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 70 deletions.
4 changes: 2 additions & 2 deletions configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ MANAGER:
REQUEST_MIGRATION_POLICY: 'SJF'

MIGRATION_BACKEND: 'gloo'
MIGRATION_CACHE_BLOCKS: 512
MIGRATION_INTERNAL_CACHE_SIZE: 1
MIGRATION_BUFFER_BLOCKS: 512
MIGRATION_INTERNAL_BUFFER_NUM: 2

ENABLE_SCALING: False
12 changes: 6 additions & 6 deletions docs/Arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--gpu-type GPU_TYPE]
[--polling-interval POLLING_INTERVAL]
[--migration-backend {gloo,nccl,rpc}]
[--migration-cache-blocks MIGRATION_CACHE_BLOCKS]
[--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS]
[--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT]
[--migration-num-layers MIGRATION_NUM_LAYERS]
[--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
[--max-stages MAX_STAGES]
[--enable-pd-disagg]
[--num-dispatch-instances NUM_DISPATCH_INSTANCES]
[--migration-internal-cache-size MIGRATION_INTERNAL_CACHE_SIZE]
[--migration-internal-buffer-num MIGRATION_INTERNAL_BUFFER_NUM]
[--log-request-timestamps]
```
Expand Down Expand Up @@ -148,8 +148,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Possible choices: gloo, rpc
- Default: "rpc"

`--migration-cache-blocks`
- Number of cache blocks in migration.
`--migration-buffer-blocks`
- Number of cache blocks in each migration buffer.
- Default: 512

`--migration-backend-init-timeout`
Expand All @@ -168,8 +168,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Drop migration if the number of stages > max_stages.
- Default: 3

`--migration-internal-cache-size`
- Number of internal cache size in migration backend for sending and receiving
`--migration-internal-buffer-num`
- Number of the buffer in migration backend for sending and receiving
- Default: 2

`--log-request-timestamps`
Expand Down
16 changes: 8 additions & 8 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ class EngineManagerArgs:

migration_backend_init_timeout: float = None
migration_backend: str = None
migration_cache_blocks: int = None
migration_buffer_blocks: int = None
migration_num_layers: int = None
last_stage_max_blocks: int = None
max_stages: int = None
migration_internal_cache_size: int = None
migration_internal_buffer_num: int = None

enable_pd_disagg: bool = None

Expand Down Expand Up @@ -175,12 +175,12 @@ def create_global_scheduler_configs(
def create_migration_config(self) -> MigrationConfig:
migration_config = MigrationConfig(self.request_migration_policy,
self.migration_backend,
self.migration_cache_blocks,
self.migration_buffer_blocks,
self.migration_num_layers,
self.last_stage_max_blocks,
self.max_stages,
self.migration_backend_init_timeout,
self.migration_internal_cache_size)
self.migration_internal_buffer_num)
return migration_config

@classmethod
Expand Down Expand Up @@ -297,18 +297,18 @@ def add_cli_args(
parser.add_argument('--migration-backend-init-timeout',
type=float,
help='timeout(s) for initializing migration backend')
parser.add_argument('--migration-cache-blocks',
parser.add_argument('--migration-buffer-blocks',
type=int,
help='number of cache blocks in migration')
help='number of cache blocks in each migration buffer')
parser.add_argument('--migration-num-layers',
type=int,
help='number of kv-cache layers to transfer in each round during migration')
parser.add_argument('--last-stage-max-blocks',
type=int,
help='if the number pf remain blocks < last_stage_max_blocks, do last stage migration')
parser.add_argument('--migration-internal-cache-size',
parser.add_argument('--migration-internal-buffer-num',
type=int,
help='number of internal cache size in migration backend for sending and receiving')
help='number of the buffer in migration backend for sending and receiving')
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')
Expand Down
20 changes: 10 additions & 10 deletions llumnix/backends/migration_backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,23 @@ def do_send(self, dst_handle, blocks: List[int]):
def do_recv(self, src_handle, blocks: List[int]):
raise NotImplementedError

class CacheMigrationBackend(MigrationBackendBase):
def __init__(self, num_cache, cache_shape, cache_dtype, cache_device, pin_memory, *args, **kwargs):
class BufferMigrationBackend(MigrationBackendBase):
def __init__(self, num_buffer, cache_shape, cache_dtype, cache_device, pin_memory, *args, **kwargs):
super().__init__(*args, **kwargs)

self.num_cache = num_cache
self.num_buffer = num_buffer

self.dummy_cache = [
self.dummy_buffer = [
torch.empty(size=cache_shape, dtype=cache_dtype, device=cache_device, pin_memory=pin_memory)
for _ in range(self.num_cache)
for _ in range(self.num_buffer)
]

self.avaiable_cache_queue = queue.Queue()
for i in range(self.num_cache):
self.avaiable_cache_queue.put_nowait(i)
self.avaiable_buffer_queue = queue.Queue()
for i in range(self.num_buffer):
self.avaiable_buffer_queue.put_nowait(i)

def get_available_cache(self):
return self.avaiable_cache_queue.get()
return self.avaiable_buffer_queue.get()

def put_back_cache(self, cache_id):
self.avaiable_cache_queue.put_nowait(cache_id)
self.avaiable_buffer_queue.put_nowait(cache_id)
13 changes: 8 additions & 5 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,14 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
self.add_running_request(backend_request)

async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
await dst_ray_actor.execute_engine_method.remote("_run_workers",
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list)
try:
await dst_ray_actor.execute_engine_method.remote("_run_workers",
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list)
except Exception as e:
logger.error(f"Error in migration backend: {e}")

def _run_workers(self, *args, **kwargs):
# pylint: disable=protected-access
Expand Down
44 changes: 22 additions & 22 deletions llumnix/backends/vllm/migration_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from vllm.worker.cache_engine import CacheEngine
from llumnix.internal_config import MigrationConfig
from llumnix.backends.migration_backend_interface import MigrationBackendBase, CacheMigrationBackend
from llumnix.backends.migration_backend_interface import MigrationBackendBase, BufferMigrationBackend
from llumnix.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -44,7 +44,7 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs):

NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16]

class RayRpcMigrationBackend(CacheMigrationBackend):
class RayRpcMigrationBackend(BufferMigrationBackend):
def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \
scheduling_strategy, is_driver_worker, gpu_cache) -> None:
self.migration_config = migration_config
Expand All @@ -65,12 +65,12 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.cache_device = "cpu"
self.num_migration_cache_blocks = self.migration_config.migration_cache_blocks
self.num_migration_buffer_blocks = self.migration_config.migration_buffer_blocks
self.num_layers = self.cache_engine.num_layers
self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size
cache_shape = (self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size)
cache_shape = (self.num_migration_buffer_blocks, self.num_layers, 2, self.migration_cache_size)

super().__init__(migration_config.migration_internal_cache_size, cache_shape, self.cache_engine.dtype,
super().__init__(migration_config.migration_internal_buffer_num, cache_shape, self.cache_engine.dtype,
self.cache_device, pin_memory=True)

def init_backend(self, group_name, world_size, rank) -> bool:
Expand All @@ -93,8 +93,8 @@ def warmup(self) -> bool:
def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
rpc_numpy_cache = None
for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks):
offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx)
for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks):
offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx)
send_blocks = src_blocks[start_idx:start_idx+offset]
ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks)
if rpc_numpy_cache is not None:
Expand All @@ -107,7 +107,7 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
dummy_cache_idx = self.get_available_cache()
send_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}
with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
Expand All @@ -121,7 +121,7 @@ def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
dummy_cache_idx = self.get_available_cache()
recv_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size)
# use pin memory dummy_cache to speed up data transfer
recv_cache.copy_(torch.from_numpy(src_handle))

Expand All @@ -144,14 +144,14 @@ def try_import_gloo():
except ImportError as e:
raise ImportError("Gloo is not installed. Please install it first.") from e

class RayColMigrationBackend(CacheMigrationBackend):
class RayColMigrationBackend(BufferMigrationBackend):
def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, local_rank,
scheduling_strategy, is_driver_worker, gpu_cache) -> None:
self.migration_config = migration_config
self.cache_engine = cache_engine
self.backend = migration_config.migration_backend
self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine.num_layers)
self.num_migration_cache_blocks = migration_config.migration_cache_blocks
self.num_migration_buffer_blocks = migration_config.migration_buffer_blocks

self.backend = migration_config.migration_backend
self.global_world_size = -1
Expand All @@ -174,8 +174,8 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.cache_device = torch.device(f"cuda:{self.local_rank}")

pin_memory = (self.backend == 'gloo')
cache_shape = (self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size)
super().__init__(migration_config.migration_internal_cache_size, cache_shape, self.cache_engine.dtype,
cache_shape = (self.num_migration_buffer_blocks, self.migration_num_layers, 2, self.migration_cache_size)
super().__init__(migration_config.migration_internal_buffer_num, cache_shape, self.cache_engine.dtype,
self.cache_device, pin_memory=pin_memory)

def init_backend(self, group_name, world_size, rank) -> bool:
Expand Down Expand Up @@ -221,7 +221,7 @@ def destory_backend(self) -> None:
def warmup(self) -> bool:
if self.global_world_size > 1:
try:
col.allreduce(self.dummy_cache[0], self.group_name)
col.allreduce(self.dummy_buffer[0], self.group_name)
# pylint: disable=W0703
except Exception as e:
logger.info("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {}), err: {}."
Expand All @@ -238,8 +238,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
tot_blocks = len(src_blocks)
src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_global_rank"))

for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks):
offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx)
for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks):
offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx)
send_blocks = src_blocks[start_idx:start_idx+offset]
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.global_rank, send_blocks)
Expand All @@ -248,7 +248,7 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]
def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
dummy_cache_idx = self.get_available_cache()
send_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}

with self.migration_stream:
Expand All @@ -265,7 +265,7 @@ def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
dummy_cache_idx = self.get_available_cache()
recv_cache = self.dummy_cache[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)

with self.migration_stream:
for layer_idx in range(self.cache_engine.num_layers):
Expand All @@ -278,10 +278,10 @@ def do_recv(self, src_handle, blocks: List[int]):

def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy,
is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase:
if cache_engine.num_gpu_blocks < migration_config.migration_cache_blocks:
logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
.format(migration_config.migration_cache_blocks, cache_engine.num_gpu_blocks))
migration_config.migration_cache_blocks = cache_engine.num_gpu_blocks
if cache_engine.num_gpu_blocks < migration_config.migration_buffer_blocks:
logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
.format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks))
migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks

target_col = None
backend = migration_config.migration_backend
Expand Down
8 changes: 4 additions & 4 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def get_global_rank(self):

def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_config: ModelConfig,
cache_config: CacheConfig, parallel_config: ParallelConfig) -> int:
migrate_cache_blocks_size = migration_config.migration_cache_blocks
migrate_cache_blocks_size = migration_config.migration_buffer_blocks
migrate_num_layers = migration_config.migration_num_layers
dummy_cache_size = migration_config.migration_internal_cache_size * migrate_num_layers * migrate_cache_blocks_size \
dummy_cache_size = migration_config.migration_internal_buffer_num * migrate_num_layers * migrate_cache_blocks_size \
* CacheEngine.get_cache_block_size(cache_config, model_config, parallel_config) \
// model_config.get_num_layers(parallel_config)

Expand Down Expand Up @@ -113,14 +113,14 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block
self.migration_backend.migrate_cache(src_worker_handle, src_blocks, dst_blocks)
# pylint: disable=broad-except
except Exception as e:
logger.info("[migrate_cache] self.rank: {}, src_worker_handle {}, meet err : {}"
logger.info("[migrate_cache] self.rank: {}, src_worker_handle {}, meet error : {}"
.format(self.rank, src_worker_handle, e))
end_time = time.time()

total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size(
self.cache_config, self.model_config, self.parallel_config)
speed = total_kv_cache_size/_GB/(end_time - start_time)
logger.info("[migrate cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s."
logger.info("[migrate_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s."
.format(len(src_blocks), convert_bytes(total_kv_cache_size), end_time-start_time, speed))

def do_recv(self, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions llumnix/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@
# Timeout(s) for initializing migration backend
_C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0
# Number of cache blocks in migration
_C.MANAGER.MIGRATION_CACHE_BLOCKS = 512
_C.MANAGER.MIGRATION_BUFFER_BLOCKS = 512
# Number of kv-cache layers to transfer in each round during migration
_C.MANAGER.MIGRATION_NUM_LAYERS = 1
# Number of internal cache size in migration backend for sending and receiving
_C.MANAGER.MIGRATION_INTERNAL_CACHE_SIZE = 2
_C.MANAGER.MIGRATION_INTERNAL_BUFFER_NUM = 2

# -----------------------------------------------------------------------------
# SCALING CONFIGURATION
Expand Down
Loading

0 comments on commit 923fe8d

Please sign in to comment.