Skip to content

Commit

Permalink
Merge pull request #22 from xadrianzetx/catch-trial-exc
Browse files Browse the repository at this point in the history
Add `catch` functionality to distributed study
  • Loading branch information
xadrianzetx authored Sep 16, 2022
2 parents d86bd3a + 95f0b29 commit 0670498
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ if __name__ == "__main__":
But there's more! All of the core Optuna APIs, including [storages, samplers](https://github.com/xadrianzetx/optuna-distributed/blob/main/examples/simple_storages.py) and [pruners](https://github.com/xadrianzetx/optuna-distributed/blob/main/examples/simple_pruning.py) are supported!

## What's missing?
* Arguments passed to `study.optimize` - `timeout` and `catch` are currently noops.
* Optimization with `timeout` is not currently supported.
* Support for callbacks and Optuna integration modules.
* Study APIs such as [`study.stop`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.stop) can't be called from trial at the moment.
20 changes: 12 additions & 8 deletions optuna_distributed/eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,11 @@ def run(
message.process(self.study, self.manager)
self.manager.after_message(self)

except Exception:
self.manager.stop_optimization()
# TODO(xadrianzetx) Is there a better way to do this in Optuna?
states = (TrialState.RUNNING, TrialState.WAITING)
trials = self.study.get_trials(deepcopy=False, states=states)
for trial in trials:
self.study._storage.set_trial_state_values(trial._trial_id, TrialState.FAIL)
raise
except Exception as e:
if not isinstance(e, catch):
self.manager.stop_optimization()
self._fail_unfinished_trials()
raise

if message.closing:
progress_bar.update((datetime.now() - time_start).total_seconds())
Expand All @@ -97,3 +94,10 @@ def run(
# TODO(xadrianzetx): Call callbacks here.
if self.manager.should_end_optimization():
break

def _fail_unfinished_trials(self) -> None:
    """Mark every trial still in the RUNNING or WAITING state as FAIL.

    The new state is written directly through the study's private
    storage object, one trial at a time.
    """
    # TODO(xadrianzetx) Is there a better way to do this in Optuna?
    unfinished = self.study.get_trials(
        deepcopy=False, states=(TrialState.RUNNING, TrialState.WAITING)
    )
    for pending in unfinished:
        self.study._storage.set_trial_state_values(pending._trial_id, TrialState.FAIL)
2 changes: 0 additions & 2 deletions optuna_distributed/messages/failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,4 @@ def process(self, study: "Study", manager: "OptimizationManager") -> None:
f"of the following error: {repr(self._exception)}",
exc_info=self._exc_info,
)
# TODO(xadrianzetx) Implement exception catching.
# https://github.com/optuna/optuna/blob/5d19e5e1f5dd9b3f9a11c74d215bd2a9c7ff43d2/optuna/study/_optimize.py#L229-L234
raise self._exception
34 changes: 34 additions & 0 deletions tests/test_eventloop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys

import optuna
import pytest

from optuna_distributed.eventloop import EventLoop
from optuna_distributed.managers import LocalOptimizationManager
from optuna_distributed.messages import FailedMessage
from optuna_distributed.trial import DistributedTrial


def test_raises_on_trial_exception() -> None:
    """`EventLoop.run` raises the reported error when no `catch` is given.

    The objective pushes a `FailedMessage` carrying a `ValueError` through
    the trial's connection; running the loop is expected to raise it.
    """

    def _objective(trial: DistributedTrial) -> None:
        # Report a failure to the event loop instead of returning a value.
        failure = FailedMessage(trial.trial_id, ValueError(), exc_info=sys.exc_info())
        trial.connection.put(failure)

    trial_count = 5
    loop = EventLoop(
        optuna.create_study(),
        LocalOptimizationManager(trial_count, n_jobs=1),
        objective=_objective,
    )
    with pytest.raises(ValueError):
        loop.run(trial_count, timeout=None)


def test_catches_on_trial_exception() -> None:
    """`EventLoop.run` completes when the reported error type is in `catch`.

    The objective pushes a `FailedMessage` carrying a `ValueError`; with
    `catch=(ValueError,)` the run is expected to finish without raising.
    """

    def _objective(trial: DistributedTrial) -> None:
        # Report a failure to the event loop instead of returning a value.
        failure = FailedMessage(trial.trial_id, ValueError(), exc_info=sys.exc_info())
        trial.connection.put(failure)

    trial_count = 5
    loop = EventLoop(
        optuna.create_study(),
        LocalOptimizationManager(trial_count, n_jobs=1),
        objective=_objective,
    )
    # No assertion needed: the test passes as long as nothing propagates.
    loop.run(trial_count, timeout=None, catch=(ValueError,))

0 comments on commit 0670498

Please sign in to comment.