Consistent estimation of task duration between stealing, adaptive and occupancy calculation #9000

Merged
84 changes: 28 additions & 56 deletions distributed/scheduler.py
@@ -1674,9 +1674,6 @@
#: Subset of tasks that exist in memory on more than one worker
replicated_tasks: set[TaskState]

#: Tasks with unknown duration, grouped by prefix
#: {task prefix: {ts, ts, ...}}
unknown_durations: dict[str, set[TaskState]]
Member Author

This has been moved into stealing.

task_groups: dict[str, TaskGroup]
task_prefixes: dict[str, TaskPrefix]
task_metadata: dict[Key, Any]
@@ -1776,7 +1773,6 @@
self.task_metadata = {}
self.total_nthreads = 0
self.total_nthreads_history = [(time(), 0)]
self.unknown_durations = {}
self.queued = queued
self.unrunnable = unrunnable
self.validate = validate
@@ -1855,7 +1851,6 @@
"unrunnable": self.unrunnable,
"queued": self.queued,
"n_tasks": self.n_tasks,
"unknown_durations": self.unknown_durations,
"validate": self.validate,
"tasks": self.tasks,
"task_groups": self.task_groups,
@@ -1907,7 +1902,6 @@
self.task_prefixes,
self.task_groups,
self.task_metadata,
self.unknown_durations,
self.replicated_tasks,
):
collection.clear()
@@ -1931,22 +1925,37 @@
self._network_occ_global,
)

def _get_prefix_duration(self, prefix: TaskPrefix) -> float:
Member Author

This is the single source of truth for the duration estimation.
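For context, a minimal configuration sketch (not part of this diff) showing how the two config keys named in the docstring below can be seeded; the keys are real distributed settings, but the values are purely illustrative:

```python
import dask

# Illustrative values only: seed a per-prefix estimate for "inc" tasks and the
# catch-all fallback applied to prefixes with no observed duration at all.
dask.config.set(
    {
        "distributed.scheduler.default-task-durations": {"inc": "100ms"},
        "distributed.scheduler.unknown-task-duration": "500ms",
    }
)
```

With a seed in place, `_get_prefix_duration` keeps returning it until real measurements update `TaskPrefix.duration_average`; without one, it falls back to twice the longest observed execution time for the prefix (`2 * prefix.max_exec_time`), or to the unknown-task duration if nothing has run yet.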

"""Get the estimated computation cost of the given task prefix
(not including any communication cost).

If no data has been observed, value of
`distributed.scheduler.default-task-durations` are used. If none is set
for this task, `distributed.scheduler.unknown-task-duration` is used
instead.

See Also
--------
WorkStealing.get_task_duration
"""
# TODO: Deal with unknown tasks better
assert prefix is not None
duration = prefix.duration_average
if duration < 0:
if prefix.max_exec_time > 0:
duration = 2 * prefix.max_exec_time

Codecov / codecov/patch annotation: added line distributed/scheduler.py#L1946 was not covered by tests.
else:
duration = self.UNKNOWN_TASK_DURATION
return duration

def _calc_occupancy(
self,
task_prefix_count: dict[str, int],
network_occ: float,
) -> float:
res = 0.0
for prefix_name, count in task_prefix_count.items():
# TODO: Deal with unknown tasks better
prefix = self.task_prefixes[prefix_name]
assert prefix is not None
duration = prefix.duration_average
if duration < 0:
if prefix.max_exec_time > 0:
duration = 2 * prefix.max_exec_time
else:
duration = self.UNKNOWN_TASK_DURATION
duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
res += duration * count
occ = res + network_occ / self.bandwidth
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
@@ -2536,13 +2545,6 @@
action=startstop["action"],
)

s = self.unknown_durations.pop(ts.prefix.name, set())
Member Author

I've moved this into the stealing plugin.

steal = self.extensions.get("stealing")
if steal:
for tts in s:
if tts.processing_on:
steal.recalculate_cost(tts)

############################
# Update State Information #
############################
@@ -3171,26 +3173,6 @@
nbytes = sum(dts.nbytes for dts in deps)
return nbytes / self.bandwidth

def get_task_duration(self, ts: TaskState) -> float:
"""Get the estimated computation cost of the given task (not including
any communication cost).

If no data has been observed, value of
`distributed.scheduler.default-task-durations` are used. If none is set
for this task, `distributed.scheduler.unknown-task-duration` is used
instead.
"""
prefix = ts.prefix
duration: float = prefix.duration_average
if duration >= 0:
return duration

s = self.unknown_durations.get(prefix.name)
if s is None:
self.unknown_durations[prefix.name] = s = set()
s.add(ts)
return self.UNKNOWN_TASK_DURATION

def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
"""Return set of currently valid workers for key

@@ -3569,20 +3551,15 @@
elif ts.state != "erred" and not ts.waiters:
recommendations[ts.key] = "released"

def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
def _task_to_msg(self, ts: TaskState) -> dict[str, Any]:
"""Convert a single computational task to a message"""
# FIXME: The duration attribute is not used on worker. We could save ourselves the
# time to compute and submit this
if duration < 0:
duration = self.get_task_duration(ts)
ts.run_id = next(TaskState._run_id_iterator)
assert ts.priority, ts
msg: dict[str, Any] = {
"op": "compute-task",
"key": ts.key,
"run_id": ts.run_id,
"priority": ts.priority,
"duration": duration,
"stimulus_id": f"compute-task-{time()}",
"who_has": {
dts.key: tuple(ws.address for ws in (dts.who_has or ()))
@@ -6003,12 +5980,10 @@
cleanup_delay, remove_client_from_events
)

def send_task_to_worker(
self, worker: str, ts: TaskState, duration: float = -1
) -> None:
def send_task_to_worker(self, worker: str, ts: TaskState) -> None:
"""Send a single computational task to a worker"""
try:
msg = self._task_to_msg(ts, duration)
msg = self._task_to_msg(ts)
self.worker_send(worker, msg)
except Exception as e:
logger.exception(e)
@@ -8859,10 +8834,7 @@
queued = take(100, concat([self.queued, self.unrunnable.keys()]))
queued_occupancy = 0
for ts in queued:
if ts.prefix.duration_average == -1:
queued_occupancy += self.UNKNOWN_TASK_DURATION
else:
queued_occupancy += ts.prefix.duration_average
queued_occupancy += self._get_prefix_duration(ts.prefix)
Member Author

I don't have a test for this, but the old version was definitely inconsistent.
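To make the inconsistency concrete, a standalone sketch (hypothetical, not part of the diff): for a prefix whose first task is still running (`duration_average == -1`, `max_exec_time > 0`), the old inline fallback in this estimate charged only the unknown-task default, whereas `_get_prefix_duration` charges twice the longest observed execution time:

```python
UNKNOWN_TASK_DURATION = 0.5  # illustrative; matches the documented default


def old_queued_estimate(duration_average: float, max_exec_time: float) -> float:
    # Old inline logic: fell straight back to the unknown-task default and
    # ignored max_exec_time entirely.
    if duration_average == -1:
        return UNKNOWN_TASK_DURATION
    return duration_average


def new_queued_estimate(duration_average: float, max_exec_time: float) -> float:
    # Mirrors Scheduler._get_prefix_duration introduced in this PR.
    if duration_average < 0:
        if max_exec_time > 0:
            return 2 * max_exec_time
        return UNKNOWN_TASK_DURATION
    return duration_average


# A prefix whose only task has been executing for 30s without finishing:
assert old_queued_estimate(-1.0, 30.0) == 0.5
assert new_queued_estimate(-1.0, 30.0) == 60.0
```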


tasks_ready = len(self.queued) + len(self.unrunnable)
if tasks_ready > 100:
56 changes: 40 additions & 16 deletions distributed/stealing.py
@@ -88,6 +88,9 @@ class WorkStealing(SchedulerPlugin):
metrics: dict[str, dict[int, float]]
_in_flight_event: asyncio.Event
_request_counter: int
#: Tasks with unknown duration, grouped by prefix
#: {task prefix: {ts, ts, ...}}
unknown_durations: dict[str, set[TaskState]]

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
@@ -111,6 +114,7 @@ def __init__(self, scheduler: Scheduler):
self.in_flight_occupancy = defaultdict(int)
self.in_flight_tasks = defaultdict(int)
self._in_flight_event = asyncio.Event()
self.unknown_durations = {}
self.metrics = {
"request_count_total": defaultdict(int),
"request_cost_total": defaultdict(int),
@@ -188,6 +192,13 @@ def transition(
ts = self.scheduler.tasks[key]
self.remove_key_from_stealable(ts)
self._remove_from_in_flight(ts)

if finish == "memory":
s = self.unknown_durations.pop(ts.prefix.name, set())
for tts in s:
if tts.processing_on:
self.recalculate_cost(tts)

if finish == "processing":
ts = self.scheduler.tasks[key]
self.put_key_in_stealable(ts)
@@ -223,13 +234,27 @@ def recalculate_cost(self, ts: TaskState) -> None:

def put_key_in_stealable(self, ts: TaskState) -> None:
cost_multiplier, level = self.steal_time_ratio(ts)
if cost_multiplier is not None:
assert level is not None
assert ts.processing_on
ws = ts.processing_on
worker = ws.address
self.stealable[worker][level].add(ts)
self.key_stealable[ts] = (worker, level)

if cost_multiplier is None:
Collaborator

shouldn't we start with this?

Member Author

Fair point

Member Author

Adjusted.

return

prefix = ts.prefix
duration = self.scheduler._get_prefix_duration(prefix)

assert level is not None
assert ts.processing_on
ws = ts.processing_on
worker = ws.address
self.stealable[worker][level].add(ts)
self.key_stealable[ts] = (worker, level)

if duration == ts.prefix.duration_average:
return

if prefix.name not in self.unknown_durations:
self.unknown_durations[prefix.name] = set()

self.unknown_durations[prefix.name].add(ts)

def remove_key_from_stealable(self, ts: TaskState) -> None:
result = self.key_stealable.pop(ts, None)
@@ -255,7 +280,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]:
if not ts.dependencies: # no dependencies fast path
return 0, 0

compute_time = self.scheduler.get_task_duration(ts)
compute_time = self.scheduler._get_prefix_duration(ts.prefix)

if not compute_time:
# occupancy/ws.processing[ts] is only allowed to be zero for
@@ -301,12 +326,9 @@ def move_task_request(

# TODO: occupancy no longer concats linearly so we can't easily
# assume that the network cost would go down by that much
victim_duration = self.scheduler.get_task_duration(
ts
) + self.scheduler.get_comm_cost(ts, victim)
thief_duration = self.scheduler.get_task_duration(
ts
) + self.scheduler.get_comm_cost(ts, thief)
compute = self.scheduler._get_prefix_duration(ts.prefix)
victim_duration = compute + self.scheduler.get_comm_cost(ts, victim)
thief_duration = compute + self.scheduler.get_comm_cost(ts, thief)

self.scheduler.stream_comms[victim.address].send(
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -457,8 +479,7 @@ def balance(self) -> None:
occ_victim = self._combined_occupancy(victim)
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
compute = self.scheduler.get_task_duration(ts)

compute = self.scheduler._get_prefix_duration(ts.prefix)
if (
occ_thief + comm_cost_thief + compute
<= occ_victim - (comm_cost_victim + compute) / 2
@@ -483,6 +504,8 @@ def balance(self) -> None:
occ_thief = self._combined_occupancy(thief)
nproc_thief = self._combined_nprocessing(thief)

# FIXME: In the worst case, the victim may have 3x the amount of work
# of the thief when this aborts balancing.
if not self.scheduler.is_unoccupied(
thief, occ_thief, nproc_thief
):
@@ -514,6 +537,7 @@ def restart(self, scheduler: Any) -> None:
s.clear()

self.key_stealable.clear()
self.unknown_durations.clear()

def story(self, *keys_or_ts: str | TaskState) -> list:
keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts}
21 changes: 12 additions & 9 deletions distributed/tests/test_scheduler.py
@@ -2788,24 +2788,26 @@ async def test_retire_workers_bad_params(c, s, a, b):
@gen_cluster(
client=True, config={"distributed.scheduler.default-task-durations": {"inc": 100}}
)
async def test_get_task_duration(c, s, a, b):
async def test_get_prefix_duration(c, s, a, b):
future = c.submit(inc, 1)
await future
assert 10 < s.task_prefixes["inc"].duration_average < 100

ts_pref1 = s.new_task("inc-abcdefab", None, "released")
assert 10 < s.get_task_duration(ts_pref1) < 100
assert 10 < s._get_prefix_duration(ts_pref1.prefix) < 100

extension = s.extensions["stealing"]
# make sure get_task_duration adds TaskStates to unknown dict
assert len(s.unknown_durations) == 0
assert len(extension.unknown_durations) == 0
x = c.submit(slowinc, 1, delay=0.5)
while len(s.tasks) < 3:
await asyncio.sleep(0.01)

ts = s.tasks[x.key]
assert s.get_task_duration(ts) == 0.5 # default
assert len(s.unknown_durations) == 1
assert len(s.unknown_durations["slowinc"]) == 1
assert s._get_prefix_duration(ts.prefix) == 0.5 # default

assert len(extension.unknown_durations) == 1
assert len(extension.unknown_durations["slowinc"]) == 1


@gen_cluster(client=True)
@@ -3338,10 +3340,11 @@ async def test_unknown_task_duration_config(client, s, a, b):
future = client.submit(slowinc, 1)
while not s.tasks:
await asyncio.sleep(0.001)
assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600
assert len(s.unknown_durations) == 1
assert sum(s._get_prefix_duration(ts.prefix) for ts in s.tasks.values()) == 3600
extension = s.extensions["stealing"]
assert len(extension.unknown_durations) == 1
await wait(future)
assert len(s.unknown_durations) == 0
assert len(extension.unknown_durations) == 0


@gen_cluster()
Expand Down
Loading
Loading