
Commit 7217c74

Merge pull request #49 from xadrianzetx/default-to-local
Default to `multiprocessing` backend when `LocalCluster` is used
xadrianzetx authored Nov 11, 2022
2 parents 45dd8f0 + 6531f06 commit 7217c74
Showing 3 changed files with 11 additions and 16 deletions.
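
In practical terms: a study driven by a plain Client() (which spins up a LocalCluster under the hood) now runs its trials on the local multiprocessing backend instead of round-tripping them through Dask. A minimal sketch of both paths, assuming the package's from_study entry point; the objective and the scheduler address are illustrative placeholders:

    import optuna
    import optuna_distributed
    from dask.distributed import Client

    def objective(trial):
        x = trial.suggest_float("x", -10, 10)
        return (x - 2) ** 2

    if __name__ == "__main__":
        # Client() owns a LocalCluster, so with this change `optimize`
        # runs on the local multiprocessing backend.
        study = optuna_distributed.from_study(optuna.create_study(), client=Client())

        # A client connected to an external scheduler keeps the
        # Dask-distributed manager (the address is a placeholder):
        # study = optuna_distributed.from_study(
        #     optuna.create_study(), client=Client("tcp://scheduler:8786")
        # )

        study.optimize(objective, n_trials=10)

The fallback is picked inside optimize (see optuna_distributed/study.py below), so no call-site changes are needed.
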
22 changes: 8 additions & 14 deletions optuna_distributed/managers/distributed.py
@@ -15,7 +15,6 @@
 
 from dask.distributed import Client
 from dask.distributed import Future
-from dask.distributed import LocalCluster
 from dask.distributed import Variable
 from optuna.exceptions import TrialPruned
 from optuna.study import Study
@@ -103,7 +102,6 @@ def __init__(self, client: Client, n_trials: int, heartbeat_interval: int = 60)
         self._completed_trials = 0
         self._public_channel = str(uuid.uuid4())
         self._synchronizer = _StateSynchronizer()
-        self._is_cluster_remote = not isinstance(client.cluster, LocalCluster)
 
         # Manager has write access to its own message queue as a sort of health check.
         # Basically that means we can pump event loop from callbacks running in
@@ -148,7 +146,7 @@ def _add_task_context(self, trials: List[DistributedTrial]) -> List[_TaskContext]:
         return trials_with_context
 
     def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None:
-        distributable = _distributable(objective, with_supervisor=self._is_cluster_remote)
+        distributable = _distributable(objective)
         trials = self._add_task_context(self._create_trials(study))
         self._futures = self._client.map(distributable, trials, pure=False)
         for future in self._futures:
@@ -178,13 +176,11 @@ def get_connection(self, trial_id: int) -> IPCPrimitive:
         return Queue(self._private_channels[trial_id])
 
     def stop_optimization(self) -> None:
-        # Only want to cleanup cluster that does not belong to us.
-        # TODO(xadrianzetx) Notebooks might be a special case (cleanup even with LocalCluster).
         self._client.cancel(self._futures)
-        if self._is_cluster_remote:
-            # Twice the timeout of task connection.
-            # This way even tasks waiting for message will have chance to exit.
-            self._synchronizer.emit_stop_and_wait(patience=10)
+        # Twice the timeout of task connection.
+        # This way even tasks waiting for message will have chance to exit.
+        # TODO(xadrianzetx) Accept patience as an argument to `stop_optimization`.
+        self._synchronizer.emit_stop_and_wait(patience=10)
 
     def should_end_optimization(self) -> bool:
         return self._completed_trials == self._n_trials
@@ -193,7 +189,7 @@ def register_trial_exit(self, trial_id: int) -> None:
         self._completed_trials += 1
 
 
-def _distributable(func: ObjectiveFuncType, with_supervisor: bool) -> DistributableWithContext:
+def _distributable(func: ObjectiveFuncType) -> DistributableWithContext:
     def _wrapper(context: _TaskContext) -> None:
         task_state = Variable(context.state_id)
         if task_state.get() != _TaskState.WAITING:
@@ -203,10 +199,8 @@ def _wrapper(context: _TaskContext) -> None:
         message: Message
 
         try:
-            if with_supervisor:
-                args = (threading.get_ident(), context)
-                Thread(target=_task_supervisor, args=args, daemon=True).start()
-
+            args = (threading.get_ident(), context)
+            Thread(target=_task_supervisor, args=args, daemon=True).start()
             value_or_values = func(context.trial)
             message = CompletedMessage(context.trial.trial_id, value_or_values)
             context.trial.connection.put(message)
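
With the LocalCluster case now handled upstream in study.py, this manager only ever runs against a remote cluster, so the supervisor thread is started unconditionally. As a rough illustration of the watchdog pattern _task_supervisor implements (a daemon thread polling a distributed stop flag and interrupting the objective's thread), here is a hedged sketch; the names and the stop-flag protocol are assumptions, not the module's actual code:

    import asyncio
    import ctypes

    from dask.distributed import Variable

    def supervisor_sketch(thread_id: int, stop_flag_name: str) -> None:
        # Illustrative stand-in for _task_supervisor, not the actual code.
        stop_flag = Variable(stop_flag_name)
        while True:
            try:
                if stop_flag.get(timeout=5) == "stop":
                    # Interrupt the objective's thread so that even a task
                    # blocked waiting for a message gets a chance to exit.
                    ctypes.pythonapi.PyThreadState_SetAsyncExc(
                        ctypes.c_long(thread_id), ctypes.py_object(KeyboardInterrupt)
                    )
                    return
            except (TimeoutError, asyncio.TimeoutError):
                continue  # No stop signal yet; keep watching.
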
3 changes: 2 additions & 1 deletion optuna_distributed/study.py
@@ -12,6 +12,7 @@
 from typing import Union
 
 from dask.distributed import Client
+from dask.distributed import LocalCluster
 from optuna.distributions import BaseDistribution
 from optuna.study import Study
 from optuna.study import StudyDirection
@@ -174,7 +175,7 @@ def optimize(
         terminal = Terminal(show_progress_bar, n_trials, timeout)
         manager = (
             DistributedOptimizationManager(self._client, n_trials)
-            if self._client is not None
+            if self._client is not None and not isinstance(self._client.cluster, LocalCluster)
             else LocalOptimizationManager(n_trials, n_jobs)
         )
 
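
A subtlety behind the new isinstance check: a client connected to an external scheduler has client.cluster set to None, so the check is False and the Dask-distributed manager is kept; only a client that owns its LocalCluster falls back to the local backend. An illustrative restatement of the dispatch predicate (the helper name is made up, not part of the diff):

    from typing import Optional

    from dask.distributed import Client, LocalCluster

    def uses_local_backend(client: Optional[Client]) -> bool:
        # No client at all, or a client that owns a LocalCluster, means
        # trials run on the local multiprocessing backend.
        return client is None or isinstance(client.cluster, LocalCluster)
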
2 changes: 1 addition & 1 deletion tests/test_managers.py
@@ -134,7 +134,7 @@ def _objective(trial: DistributedTrial) -> float:
 
     # Simulate scenario where task run was repeated.
     # https://stackoverflow.com/a/41965766
-    func = _distributable(_objective, with_supervisor=False)
+    func = _distributable(_objective)
     context = _TaskContext(DistributedTrial(0, Mock()), stop_flag="foo", state_id=state_id)
     for _ in range(5):
         client.submit(func, context).result()
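
For context, the guard this test exercises sits at the top of _wrapper (visible in the first file's diff): a Dask Variable records the task's state, and a re-submitted run of the same task returns early instead of executing the objective twice. A rough sketch of the idea; the enum and helper names are illustrative stand-ins for _TaskState and the real wrapper:

    from enum import IntEnum

    from dask.distributed import Variable

    class TaskStateSketch(IntEnum):  # illustrative stand-in for _TaskState
        WAITING = 0
        RUNNING = 1

    def run_once_sketch(state_id: str) -> None:
        task_state = Variable(state_id)
        if task_state.get() != TaskStateSketch.WAITING:
            # Dask re-submitted the task (e.g. after worker loss); skip it.
            return
        task_state.set(TaskStateSketch.RUNNING)
        # ...the wrapped objective would run here.
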
