-
-
Notifications
You must be signed in to change notification settings - Fork 722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Consistent estimation of task duration between stealing, adaptive and occupancy calculation #9000
Changes from all commits
14ac478
78256d3
5b88656
49b916d
54fdd21
102dc19
fb8d8ec
f1243a3
2bd934e
40aa62a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1674,9 +1674,6 @@ | |
#: Subset of tasks that exist in memory on more than one worker | ||
replicated_tasks: set[TaskState] | ||
|
||
#: Tasks with unknown duration, grouped by prefix | ||
#: {task prefix: {ts, ts, ...}} | ||
unknown_durations: dict[str, set[TaskState]] | ||
task_groups: dict[str, TaskGroup] | ||
task_prefixes: dict[str, TaskPrefix] | ||
task_metadata: dict[Key, Any] | ||
|
@@ -1776,7 +1773,6 @@ | |
self.task_metadata = {} | ||
self.total_nthreads = 0 | ||
self.total_nthreads_history = [(time(), 0)] | ||
self.unknown_durations = {} | ||
self.queued = queued | ||
self.unrunnable = unrunnable | ||
self.validate = validate | ||
|
@@ -1855,7 +1851,6 @@ | |
"unrunnable": self.unrunnable, | ||
"queued": self.queued, | ||
"n_tasks": self.n_tasks, | ||
"unknown_durations": self.unknown_durations, | ||
"validate": self.validate, | ||
"tasks": self.tasks, | ||
"task_groups": self.task_groups, | ||
|
@@ -1907,7 +1902,6 @@ | |
self.task_prefixes, | ||
self.task_groups, | ||
self.task_metadata, | ||
self.unknown_durations, | ||
self.replicated_tasks, | ||
): | ||
collection.clear() | ||
|
@@ -1931,22 +1925,37 @@ | |
self._network_occ_global, | ||
) | ||
|
||
def _get_prefix_duration(self, prefix: TaskPrefix) -> float: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the single source of truth for the duration estimation. |
||
"""Get the estimated computation cost of the given task prefix | ||
(not including any communication cost). | ||
|
||
If no data has been observed, value of | ||
`distributed.scheduler.default-task-durations` are used. If none is set | ||
for this task, `distributed.scheduler.unknown-task-duration` is used | ||
instead. | ||
|
||
See Also | ||
-------- | ||
WorkStealing.get_task_duration | ||
""" | ||
# TODO: Deal with unknown tasks better | ||
assert prefix is not None | ||
duration = prefix.duration_average | ||
if duration < 0: | ||
if prefix.max_exec_time > 0: | ||
duration = 2 * prefix.max_exec_time | ||
else: | ||
duration = self.UNKNOWN_TASK_DURATION | ||
return duration | ||
|
||
def _calc_occupancy( | ||
self, | ||
task_prefix_count: dict[str, int], | ||
network_occ: float, | ||
) -> float: | ||
res = 0.0 | ||
for prefix_name, count in task_prefix_count.items(): | ||
# TODO: Deal with unknown tasks better | ||
prefix = self.task_prefixes[prefix_name] | ||
assert prefix is not None | ||
duration = prefix.duration_average | ||
if duration < 0: | ||
if prefix.max_exec_time > 0: | ||
duration = 2 * prefix.max_exec_time | ||
else: | ||
duration = self.UNKNOWN_TASK_DURATION | ||
duration = self._get_prefix_duration(self.task_prefixes[prefix_name]) | ||
res += duration * count | ||
occ = res + network_occ / self.bandwidth | ||
assert occ >= 0, (occ, res, network_occ, self.bandwidth) | ||
|
@@ -2536,13 +2545,6 @@ | |
action=startstop["action"], | ||
) | ||
|
||
s = self.unknown_durations.pop(ts.prefix.name, set()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've moved this into the stealing plugin. |
||
steal = self.extensions.get("stealing") | ||
if steal: | ||
for tts in s: | ||
if tts.processing_on: | ||
steal.recalculate_cost(tts) | ||
|
||
############################ | ||
# Update State Information # | ||
############################ | ||
|
@@ -3171,26 +3173,6 @@ | |
nbytes = sum(dts.nbytes for dts in deps) | ||
return nbytes / self.bandwidth | ||
|
||
def get_task_duration(self, ts: TaskState) -> float: | ||
"""Get the estimated computation cost of the given task (not including | ||
any communication cost). | ||
|
||
If no data has been observed, value of | ||
`distributed.scheduler.default-task-durations` are used. If none is set | ||
for this task, `distributed.scheduler.unknown-task-duration` is used | ||
instead. | ||
""" | ||
prefix = ts.prefix | ||
duration: float = prefix.duration_average | ||
if duration >= 0: | ||
return duration | ||
|
||
s = self.unknown_durations.get(prefix.name) | ||
if s is None: | ||
self.unknown_durations[prefix.name] = s = set() | ||
s.add(ts) | ||
return self.UNKNOWN_TASK_DURATION | ||
|
||
def valid_workers(self, ts: TaskState) -> set[WorkerState] | None: | ||
"""Return set of currently valid workers for key | ||
|
||
|
@@ -3569,20 +3551,15 @@ | |
elif ts.state != "erred" and not ts.waiters: | ||
recommendations[ts.key] = "released" | ||
|
||
def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: | ||
def _task_to_msg(self, ts: TaskState) -> dict[str, Any]: | ||
"""Convert a single computational task to a message""" | ||
# FIXME: The duration attribute is not used on worker. We could save ourselves the | ||
# time to compute and submit this | ||
if duration < 0: | ||
duration = self.get_task_duration(ts) | ||
ts.run_id = next(TaskState._run_id_iterator) | ||
assert ts.priority, ts | ||
msg: dict[str, Any] = { | ||
"op": "compute-task", | ||
"key": ts.key, | ||
"run_id": ts.run_id, | ||
"priority": ts.priority, | ||
"duration": duration, | ||
"stimulus_id": f"compute-task-{time()}", | ||
"who_has": { | ||
dts.key: tuple(ws.address for ws in (dts.who_has or ())) | ||
|
@@ -6003,12 +5980,10 @@ | |
cleanup_delay, remove_client_from_events | ||
) | ||
|
||
def send_task_to_worker( | ||
self, worker: str, ts: TaskState, duration: float = -1 | ||
) -> None: | ||
def send_task_to_worker(self, worker: str, ts: TaskState) -> None: | ||
"""Send a single computational task to a worker""" | ||
try: | ||
msg = self._task_to_msg(ts, duration) | ||
msg = self._task_to_msg(ts) | ||
self.worker_send(worker, msg) | ||
except Exception as e: | ||
logger.exception(e) | ||
|
@@ -8859,10 +8834,7 @@ | |
queued = take(100, concat([self.queued, self.unrunnable.keys()])) | ||
queued_occupancy = 0 | ||
for ts in queued: | ||
if ts.prefix.duration_average == -1: | ||
queued_occupancy += self.UNKNOWN_TASK_DURATION | ||
else: | ||
queued_occupancy += ts.prefix.duration_average | ||
queued_occupancy += self._get_prefix_duration(ts.prefix) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't have a test for this, but the old version was definitely inconsistent. |
||
|
||
tasks_ready = len(self.queued) + len(self.unrunnable) | ||
if tasks_ready > 100: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,9 @@ class WorkStealing(SchedulerPlugin): | |
metrics: dict[str, dict[int, float]] | ||
_in_flight_event: asyncio.Event | ||
_request_counter: int | ||
#: Tasks with unknown duration, grouped by prefix | ||
#: {task prefix: {ts, ts, ...}} | ||
unknown_durations: dict[str, set[TaskState]] | ||
|
||
def __init__(self, scheduler: Scheduler): | ||
self.scheduler = scheduler | ||
|
@@ -111,6 +114,7 @@ def __init__(self, scheduler: Scheduler): | |
self.in_flight_occupancy = defaultdict(int) | ||
self.in_flight_tasks = defaultdict(int) | ||
self._in_flight_event = asyncio.Event() | ||
self.unknown_durations = {} | ||
self.metrics = { | ||
"request_count_total": defaultdict(int), | ||
"request_cost_total": defaultdict(int), | ||
|
@@ -188,6 +192,13 @@ def transition( | |
ts = self.scheduler.tasks[key] | ||
self.remove_key_from_stealable(ts) | ||
self._remove_from_in_flight(ts) | ||
|
||
if finish == "memory": | ||
s = self.unknown_durations.pop(ts.prefix.name, set()) | ||
for tts in s: | ||
if tts.processing_on: | ||
self.recalculate_cost(tts) | ||
|
||
if finish == "processing": | ||
ts = self.scheduler.tasks[key] | ||
self.put_key_in_stealable(ts) | ||
|
@@ -223,13 +234,27 @@ def recalculate_cost(self, ts: TaskState) -> None: | |
|
||
def put_key_in_stealable(self, ts: TaskState) -> None: | ||
cost_multiplier, level = self.steal_time_ratio(ts) | ||
if cost_multiplier is not None: | ||
assert level is not None | ||
assert ts.processing_on | ||
ws = ts.processing_on | ||
worker = ws.address | ||
self.stealable[worker][level].add(ts) | ||
self.key_stealable[ts] = (worker, level) | ||
|
||
if cost_multiplier is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't we start with this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair point. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adjusted. |
||
return | ||
|
||
prefix = ts.prefix | ||
duration = self.scheduler._get_prefix_duration(prefix) | ||
|
||
assert level is not None | ||
assert ts.processing_on | ||
ws = ts.processing_on | ||
worker = ws.address | ||
self.stealable[worker][level].add(ts) | ||
self.key_stealable[ts] = (worker, level) | ||
|
||
if duration == ts.prefix.duration_average: | ||
return | ||
|
||
if prefix.name not in self.unknown_durations: | ||
self.unknown_durations[prefix.name] = set() | ||
|
||
self.unknown_durations[prefix.name].add(ts) | ||
|
||
def remove_key_from_stealable(self, ts: TaskState) -> None: | ||
result = self.key_stealable.pop(ts, None) | ||
|
@@ -255,7 +280,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non | |
if not ts.dependencies: # no dependencies fast path | ||
return 0, 0 | ||
|
||
compute_time = self.scheduler.get_task_duration(ts) | ||
compute_time = self.scheduler._get_prefix_duration(ts.prefix) | ||
|
||
if not compute_time: | ||
# occupancy/ws.processing[ts] is only allowed to be zero for | ||
|
@@ -301,12 +326,9 @@ def move_task_request( | |
|
||
# TODO: occupancy no longer concats linearly so we can't easily | ||
# assume that the network cost would go down by that much | ||
victim_duration = self.scheduler.get_task_duration( | ||
ts | ||
) + self.scheduler.get_comm_cost(ts, victim) | ||
thief_duration = self.scheduler.get_task_duration( | ||
ts | ||
) + self.scheduler.get_comm_cost(ts, thief) | ||
compute = self.scheduler._get_prefix_duration(ts.prefix) | ||
victim_duration = compute + self.scheduler.get_comm_cost(ts, victim) | ||
thief_duration = compute + self.scheduler.get_comm_cost(ts, thief) | ||
|
||
self.scheduler.stream_comms[victim.address].send( | ||
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id} | ||
|
@@ -457,8 +479,7 @@ def balance(self) -> None: | |
occ_victim = self._combined_occupancy(victim) | ||
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) | ||
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) | ||
compute = self.scheduler.get_task_duration(ts) | ||
|
||
compute = self.scheduler._get_prefix_duration(ts.prefix) | ||
if ( | ||
occ_thief + comm_cost_thief + compute | ||
<= occ_victim - (comm_cost_victim + compute) / 2 | ||
|
@@ -483,6 +504,8 @@ def balance(self) -> None: | |
occ_thief = self._combined_occupancy(thief) | ||
nproc_thief = self._combined_nprocessing(thief) | ||
|
||
# FIXME: In the worst case, the victim may have 3x the amount of work | ||
# of the thief when this aborts balancing. | ||
if not self.scheduler.is_unoccupied( | ||
thief, occ_thief, nproc_thief | ||
): | ||
|
@@ -514,6 +537,7 @@ def restart(self, scheduler: Any) -> None: | |
s.clear() | ||
|
||
self.key_stealable.clear() | ||
self.unknown_durations.clear() | ||
|
||
def story(self, *keys_or_ts: str | TaskState) -> list: | ||
keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has been moved into stealing.