Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add catch functionality to distributed study #22

Merged
merged 3 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ if __name__ == "__main__":
But there's more! All of the core Optuna APIs, including [storages, samplers](https://github.com/xadrianzetx/optuna-distributed/blob/main/examples/simple_storages.py) and [pruners](https://github.com/xadrianzetx/optuna-distributed/blob/main/examples/simple_pruning.py) are supported!

## What's missing?
* Arguments passed to `study.optimize` - `timeout` and `catch` are currently no-ops.
* Optimization with `timeout` is not currently supported.
* Support for callbacks and Optuna integration modules.
* Study APIs such as [`study.stop`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.stop) cannot be called from within a trial at the moment.
20 changes: 12 additions & 8 deletions optuna_distributed/eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,11 @@ def run(
message.process(self.study, self.manager)
self.manager.after_message(self)

except Exception:
self.manager.stop_optimization()
# TODO(xadrianzetx) Is there a better way to do this in Optuna?
states = (TrialState.RUNNING, TrialState.WAITING)
trials = self.study.get_trials(deepcopy=False, states=states)
for trial in trials:
self.study._storage.set_trial_state_values(trial._trial_id, TrialState.FAIL)
raise
except Exception as e:
if not isinstance(e, catch):
self.manager.stop_optimization()
self._fail_unfinished_trials()
raise

if message.closing:
progress_bar.update((datetime.now() - time_start).total_seconds())
Expand All @@ -97,3 +94,10 @@ def run(
# TODO(xadrianzetx): Call callbacks here.
if self.manager.should_end_optimization():
break

def _fail_unfinished_trials(self) -> None:
    """Set the state of every running or waiting trial of the study to FAIL.

    Trials are looked up on ``self.study`` (without deep-copying them) and
    their state is overwritten directly through the study's storage.
    """
    # TODO(xadrianzetx) Is there a better way to do this in Optuna?
    unfinished_states = (TrialState.RUNNING, TrialState.WAITING)
    for unfinished in self.study.get_trials(deepcopy=False, states=unfinished_states):
        # Writes go through the private storage handle of the study.
        self.study._storage.set_trial_state_values(unfinished._trial_id, TrialState.FAIL)
2 changes: 0 additions & 2 deletions optuna_distributed/messages/failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,4 @@ def process(self, study: "Study", manager: "OptimizationManager") -> None:
f"of the following error: {repr(self._exception)}",
exc_info=self._exc_info,
)
# TODO(xadrianzetx) Implement exception catching.
# https://github.com/optuna/optuna/blob/5d19e5e1f5dd9b3f9a11c74d215bd2a9c7ff43d2/optuna/study/_optimize.py#L229-L234
raise self._exception
34 changes: 34 additions & 0 deletions tests/test_eventloop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys

import optuna
import pytest

from optuna_distributed.eventloop import EventLoop
from optuna_distributed.managers import LocalOptimizationManager
from optuna_distributed.messages import FailedMessage
from optuna_distributed.trial import DistributedTrial


def test_raises_on_trial_exception() -> None:
    """Without ``catch``, a ``ValueError`` reported by a trial leaves ``run``."""

    def _objective(trial: DistributedTrial) -> None:
        # The failure is reported as a message on the trial connection
        # rather than raised from the objective itself.
        error = ValueError()
        failure = FailedMessage(trial.trial_id, error, exc_info=sys.exc_info())
        trial.connection.put(failure)

    trial_count = 5
    study = optuna.create_study()
    loop = EventLoop(
        study,
        LocalOptimizationManager(trial_count, n_jobs=1),
        objective=_objective,
    )
    with pytest.raises(ValueError):
        loop.run(trial_count, timeout=None)


def test_catches_on_trial_exception() -> None:
    """With ``catch=(ValueError,)``, ``run`` is expected to finish without raising."""

    def _objective(trial: DistributedTrial) -> None:
        # The failure is reported as a message on the trial connection
        # rather than raised from the objective itself.
        error = ValueError()
        failure = FailedMessage(trial.trial_id, error, exc_info=sys.exc_info())
        trial.connection.put(failure)

    trial_count = 5
    study = optuna.create_study()
    loop = EventLoop(
        study,
        LocalOptimizationManager(trial_count, n_jobs=1),
        objective=_objective,
    )
    # The test passes as long as this call does not raise.
    loop.run(trial_count, timeout=None, catch=(ValueError,))