Skip to content

Commit

Permalink
Merge pull request #32 from xadrianzetx/task-dedupe-v2
Browse files Browse the repository at this point in the history
Task state based deduplication
  • Loading branch information
xadrianzetx authored Oct 19, 2022
2 parents c4aa7f8 + d37f636 commit 8101460
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 46 deletions.
4 changes: 3 additions & 1 deletion optuna_distributed/managers/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,10 @@ def register_trial_exit(self, trial_id: int) -> None:

def _distributable(func: ObjectiveFuncType, with_supervisor: bool) -> DistributableWithContext:
def _wrapper(context: _TaskContext) -> None:
# FIXME: Re-introduce task deduplication.
task_state = Variable(context.state_id)
if task_state.get() != _TaskState.WAITING:
return

task_state.set(_TaskState.RUNNING)
message: Message

Expand Down
2 changes: 0 additions & 2 deletions optuna_distributed/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from optuna_distributed.messages.property import TrialProperty
from optuna_distributed.messages.property import TrialPropertyMessage
from optuna_distributed.messages.pruned import PrunedMessage
from optuna_distributed.messages.repeated import RepeatedTrialMessage
from optuna_distributed.messages.report import ReportMessage
from optuna_distributed.messages.response import ResponseMessage
from optuna_distributed.messages.setattr import AttributeType
Expand All @@ -21,7 +20,6 @@
"SuggestMessage",
"CompletedMessage",
"FailedMessage",
"RepeatedTrialMessage",
"PrunedMessage",
"ReportMessage",
"ShouldPruneMessage",
Expand Down
33 changes: 0 additions & 33 deletions optuna_distributed/messages/repeated.py

This file was deleted.

25 changes: 25 additions & 0 deletions tests/test_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import multiprocessing
import sys
import time
from unittest.mock import Mock
import uuid

from dask.distributed import Client
from dask.distributed import Variable
Expand All @@ -13,7 +15,9 @@
from optuna_distributed.managers import LocalOptimizationManager
from optuna_distributed.managers import ObjectiveFuncType
from optuna_distributed.managers.distributed import _StateSynchronizer
from optuna_distributed.managers.distributed import _TaskContext
from optuna_distributed.managers.distributed import _TaskState
from optuna_distributed.managers.distributed import _distributable
from optuna_distributed.messages import CompletedMessage
from optuna_distributed.messages import HeartbeatMessage
from optuna_distributed.messages import ResponseMessage
Expand Down Expand Up @@ -117,6 +121,27 @@ def _objective(trial: DistributedTrial) -> float:
break


def test_distributed_task_deduped(client: Client) -> None:
def _objective(trial: DistributedTrial) -> float:
run_count = Variable("run_count")
run_count.set(run_count.get() + 1)
return 0.0

run_count = Variable("run_count")
run_count.set(0)
state_id = uuid.uuid4().hex
Variable(state_id).set(_TaskState.WAITING)

# Simulate scenario where task run was repeated.
# https://stackoverflow.com/a/41965766
func = _distributable(_objective, with_supervisor=False)
context = _TaskContext(DistributedTrial(0, Mock()), stop_flag="foo", state_id=state_id)
for _ in range(5):
client.submit(func, context).result()

assert run_count.get() == 1


def test_synchronizer_optimization_enabled() -> None:
synchronizer = _StateSynchronizer()
optimization_enabled = Variable(synchronizer.stop_flag)
Expand Down
10 changes: 0 additions & 10 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from optuna_distributed.messages import HeartbeatMessage
from optuna_distributed.messages import Message
from optuna_distributed.messages import PrunedMessage
from optuna_distributed.messages import RepeatedTrialMessage
from optuna_distributed.messages import ReportMessage
from optuna_distributed.messages import ResponseMessage
from optuna_distributed.messages import SetAttributeMessage
Expand Down Expand Up @@ -179,15 +178,6 @@ def test_should_prune(study: Study, manager: MockOptimizationManager) -> None:
assert trial[0]._trial_id == 0


def test_repeated_trial(study: Study, manager: MockOptimizationManager) -> None:
msg = RepeatedTrialMessage(0)
assert not msg.closing

study.tell(0, state=TrialState.PRUNED)
msg.process(study, manager)
assert _message_responds_with(True, manager=manager)


def test_report_intermediate(study: Study, manager: MockOptimizationManager) -> None:
msg = ReportMessage(0, value=0.0, step=1)
assert not msg.closing
Expand Down

0 comments on commit 8101460

Please sign in to comment.