Skip to content

Commit

Permalink
[Refactor] refactor migration scheduler (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui authored Nov 15, 2024
1 parent 8e52054 commit bcd49ba
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 201 deletions.
3 changes: 2 additions & 1 deletion llumnix/global_scheduler/global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from llumnix.internal_config import GlobalSchedulerConfig
from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints
from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler

logger = init_logger(__name__)
Expand Down
149 changes: 149 additions & 0 deletions llumnix/global_scheduler/migration_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Optional
from abc import ABC, abstractmethod

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.global_scheduler.scaling_scheduler import InstanceType
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints

logger = init_logger(__name__)

class MigrationFilterConfig:
def __init__(self, migrate_out_load_threshold):
self.migrate_out_load_threshold: float = migrate_out_load_threshold

# TODO(KuilongCui): A filter might contain other filters; leave this for the future.
class MigrationFilterPolicy(ABC):
@abstractmethod
def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

@abstractmethod
def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

class MigrationInstanceFilter(ABC):
def __init__(self, filter_config: MigrationFilterConfig) -> None:
self.filter_config = filter_config
self.registered_filters: Dict[str, MigrationFilterPolicy] = {}

def register_filter(self, filter_name: str, migration_filter: MigrationFilterPolicy) -> bool:
if filter_name in self.registered_filters:
logger.warning("migration filter {} has been registered.".format(filter_name))
return False

self.registered_filters[filter_name] = migration_filter
return True

def unregister_filter(self, filter_name: str) -> None:
self.registered_filters.pop(filter_name, None)

def get_filter(self, filter_name: str) -> Optional[MigrationFilterPolicy]:
return self.registered_filters.get(filter_name, None)

def filter_instances(self, instance_infos: List[InstanceInfo],
pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]:
src_filter_conditions = [filter.filter_src_condition() for filter in self.registered_filters.values()]
dst_filter_conditions = [filter.filter_dst_condition() for filter in self.registered_filters.values()]

if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
policy_filter = MigrationFilterPolicyFactory.get_policy("load")
elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]:
policy_filter = MigrationFilterPolicyFactory.get_policy('prefill_decode')
else:
raise ValueError(f"Unsupported pair migration type: {pair_migration_type}")
src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type))
dst_filter_conditions.append(policy_filter.filter_dst_condition(self.filter_config, pair_migration_type))

filtered_src_instance_infos = [info for info in instance_infos if all(cond(info) for cond in src_filter_conditions)]
filtered_dst_instance_infos = [info for info in instance_infos if all(cond(info) for cond in dst_filter_conditions)]

return filtered_src_instance_infos, filtered_dst_instance_infos

class LoadConstrainedFilter(MigrationFilterPolicy):
def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests > 0 \
or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests == 0 \
and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold

class PddFilter(MigrationFilterPolicy):
INSTANCE_FILTER_RULES = {
PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE),
PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE),
}

def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
src_type, _ = self.INSTANCE_FILTER_RULES[pair_migration_type]
instance_type_filter = lambda instance_info: instance_info.instance_type == src_type

if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
inner_policy = MigrationFilterPolicyFactory.get_policy('load')
policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type)
else:
policy_filter = lambda instance_info: True

return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info)

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
_, dst_type = self.INSTANCE_FILTER_RULES[pair_migration_type]
instance_type_filter = lambda instance_info: instance_info.instance_type == dst_type

if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
inner_policy = MigrationFilterPolicyFactory.get_policy('load')
policy_filter = inner_policy.filter_dst_condition(filter_config, pair_migration_type)
else:
policy_filter = lambda instance_info: instance_info.num_killed_requests == 0

return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info)

class CustomFilter(MigrationFilterPolicy):
def __init__(self):
super().__init__()
self.src_filter = lambda _: True
self.dst_filter = lambda _: True

def set_filter_condtition(self, src_filter: Optional[Callable[[InstanceInfo], bool]] = None,
dst_filter: Optional[Callable[[InstanceInfo], bool]] = None) -> None:
if src_filter:
self.src_filter = src_filter
if dst_filter:
self.dst_filter = dst_filter

def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return self.src_filter

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return self.dst_filter

class MigrationFilterPolicyFactory:
_POLICY_REGISTRY = {
'load': LoadConstrainedFilter,
'prefill_decode': PddFilter,
'custom': CustomFilter,
}

@classmethod
def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> MigrationFilterPolicy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
113 changes: 113 additions & 0 deletions llumnix/global_scheduler/migration_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from abc import ABC, abstractmethod
from enum import Enum
import copy
import numpy as np

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator

logger = init_logger(__name__)

class PairMigrationConstraints(str, Enum):
"""Target of Migration."""
NO_CONSTRAINTS = "NO_CONSTRAINTS"

# Enable the prefill-decoding disaggregration.
DECODING_2_DECODING = "DECODING_2_DECODING"
PREFILL_2_DECODING = "PREFILL_2_DECODING"

class PairMigrationPolicy(ABC):
def __init__(self,
migrate_out_load_threshold: float,
instance_load_calculator: InstanceLoadCalculator) -> None:
self.migrate_out_load_threshold = migrate_out_load_threshold
self.instance_load_calculator = instance_load_calculator

@abstractmethod
def pair_migration(self,
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
raise NotImplementedError

def sort_instance_infos(self, instance_infos: List[InstanceInfo], descending: bool = True) -> None:
key_attr = 'instance_load_migrate'
sorted_instance_infos = sorted(
instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
return sorted_instance_infos

class Balanced(PairMigrationPolicy):
def pair_migration(self,
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True)
sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate

left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False)
right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True)

# Add some constrains to reduce unnecessary migrations
if right_load_after_mig > self.migrate_out_load_threshold:
continue
load_diff_after_mig = left_load_after_mig - right_load_after_mig
if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf):
migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id,
sorted_dst_instance_infos[i].instance_id))
return migrate_instance_pairs

def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float:
instance_info_after_migrate = copy.deepcopy(instance_info)
num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request

if is_migrate_in:
instance_info_after_migrate.num_running_requests += 1
instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request
else:
instance_info_after_migrate.num_running_requests -= 1
instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request

return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate')

class DefragConstrained(PairMigrationPolicy):
def pair_migration(self,
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True)
sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
# without any constrain in order to make prefill migrate happens as soon as possible
migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id))
return migrate_instance_pairs

class PairMigrationPolicyFactory:
_POLICY_REGISTRY = {
'balanced': Balanced,
'defrag_constrained': DefragConstrained,
}

@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> PairMigrationPolicy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
Loading

0 comments on commit bcd49ba

Please sign in to comment.