Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fancy terminal output #38

Merged
merged 6 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions optuna_distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
def _setup_logger() -> None:
import logging

import colorlog
from rich.logging import RichHandler

fmt = colorlog.ColoredFormatter(
"%(log_color)s[%(levelname)1.1s %(asctime)s]%(reset)s %(message)s"
)
handler = logging.StreamHandler()
handler = RichHandler(show_path=False)
fmt = logging.Formatter(fmt="%(message)s", datefmt="[%X]")
handler.setFormatter(fmt)
root_logger = logging.getLogger(__name__)
root_logger.addHandler(handler)
Expand Down
33 changes: 11 additions & 22 deletions optuna_distributed/eventloop.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from datetime import datetime
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

# FIXME: We should probably implement our own progress bar.
from optuna.progress_bar import _ProgressBar
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState

from optuna_distributed.managers import ObjectiveFuncType
from optuna_distributed.managers import OptimizationManager
from optuna_distributed.terminal import Terminal


class EventLoop:
Expand Down Expand Up @@ -45,32 +41,23 @@ def __init__(

def run(
self,
n_trials: Optional[int] = None,
terminal: Terminal,
timeout: Optional[float] = None,
catch: Tuple[Type[Exception], ...] = (),
callbacks: Optional[List[Callable[[Study, FrozenTrial], None]]] = None,
show_progress_bar: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""Starts the event loop.

Args:
n_trials:
The number of trials for each process.
terminal:
An instance of :obj:`optuna_distributed.terminal.Terminal`.
timeout:
Stops study after the given number of second(s).
catch:
A tuple of exceptions to ignore if any is raised while optimizing a function.
callbacks:
List of callback functions that are invoked at the end of each trial. Not supported
at the moment.
show_progress_bar:
A flag to include tqdm-style progress bar.
"""
time_start = datetime.now()
progress_bar = _ProgressBar(show_progress_bar, n_trials, timeout)

self.manager.create_futures(self.study, self.objective)
for message in self.manager.get_message():
try:
Expand All @@ -80,21 +67,23 @@ def run(

except Exception as e:
if not isinstance(e, catch):
self.manager.stop_optimization()
self._fail_unfinished_trials()
with terminal.spin_while_trials_interrupted():
self.manager.stop_optimization()
self._fail_unfinished_trials()
raise

elapsed = (datetime.now() - time_start).total_seconds()
if timeout is not None and elapsed > timeout:
self.manager.stop_optimization()
with terminal.spin_while_trials_interrupted():
self.manager.stop_optimization()
break

if message.closing:
progress_bar.update(elapsed)
terminal.update_progress_bar()

# TODO(xadrianzetx): Call callbacks here.
if self.manager.should_end_optimization():
progress_bar.close()
terminal.close_progress_bar()
break

def _fail_unfinished_trials(self) -> None:
Expand Down
4 changes: 0 additions & 4 deletions optuna_distributed/managers/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import ctypes
from dataclasses import dataclass
from enum import IntEnum
import logging
import sys
import threading
from threading import Thread
Expand Down Expand Up @@ -38,7 +37,6 @@


DistributableWithContext = Callable[["_TaskContext"], None]
_logger = logging.getLogger(__name__)


class WorkerInterrupted(Exception):
Expand Down Expand Up @@ -77,12 +75,10 @@ def set_initial_state(self) -> str:
def emit_stop_and_wait(self, patience: int) -> None:
self._optimization_enabled.set(False)
disabled_at = time.time()
_logger.info("Interrupting running tasks...")
while any(state.get() == _TaskState.RUNNING for state in self._task_states):
if time.time() - disabled_at > patience:
raise TimeoutError("Timed out while trying to interrupt running tasks.")
time.sleep(0.1)
_logger.info("All tasks have been stopped.")


class DistributedOptimizationManager(OptimizationManager):
Expand Down
6 changes: 0 additions & 6 deletions optuna_distributed/managers/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import multiprocessing
from multiprocessing import Pipe as MultiprocessingPipe
from multiprocessing import Process
Expand Down Expand Up @@ -30,9 +29,6 @@
from optuna_distributed.eventloop import EventLoop


_logger = logging.getLogger(__name__)


class LocalOptimizationManager(OptimizationManager):
"""Controls optimization process on local machine.

Expand Down Expand Up @@ -103,12 +99,10 @@ def get_connection(self, trial_id: int) -> IPCPrimitive:
return Pipe(self._pool[trial_id])

def stop_optimization(self) -> None:
_logger.info("Interrupting running tasks...")
for process in self._processes:
if process.is_alive():
process.kill()
process.join(timeout=10.0)
_logger.info("All tasks have been stopped.")

def should_end_optimization(self) -> bool:
return len(self._pool) == 0 and self._trials_remaining == 0
Expand Down
12 changes: 8 additions & 4 deletions optuna_distributed/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from optuna_distributed.managers import DistributedOptimizationManager
from optuna_distributed.managers import LocalOptimizationManager
from optuna_distributed.managers import ObjectiveFuncType
from optuna_distributed.terminal import Terminal


if TYPE_CHECKING:
Expand Down Expand Up @@ -153,14 +154,14 @@ def optimize(
n_trials:
The number of trials to run in total.
timeout:
Stop study after the given number of second(s). Currently noop.
Stop study after the given number of second(s).
n_jobs:
The number of parallel jobs when using multiprocessing backend. Values less than
one or greater than :obj:`multiprocessing.cpu_count()` will default to number of
logical CPU cores available.
catch:
A study continues to run even when a trial raises one of the exceptions specified
in this argument. Currently noop.
in this argument.
callbacks:
List of callback functions that are invoked at the end of each trial. Currently
not supported.
Expand All @@ -170,6 +171,7 @@ def optimize(
if n_trials is None:
raise ValueError("Only finite number of trials supported at the moment.")

terminal = Terminal(show_progress_bar, n_trials, timeout)
manager = (
DistributedOptimizationManager(self._client, n_trials)
if self._client is not None
Expand All @@ -178,10 +180,12 @@ def optimize(

try:
event_loop = EventLoop(self._study, manager, objective=func)
event_loop.run(n_trials, timeout, catch, callbacks, show_progress_bar)
event_loop.run(terminal, timeout, catch)

except KeyboardInterrupt:
manager.stop_optimization()
with terminal.spin_while_trials_interrupted():
manager.stop_optimization()

states = (TrialState.RUNNING, TrialState.WAITING)
trials = self._study.get_trials(deepcopy=False, states=states)
for trial in trials:
Expand Down
53 changes: 53 additions & 0 deletions optuna_distributed/terminal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Optional

from rich.progress import BarColumn
from rich.progress import Progress
from rich.progress import TaskProgressColumn
from rich.progress import TextColumn
from rich.progress import TimeElapsedColumn
from rich.status import Status
from rich.style import Style


class Terminal:
    """Provides styled terminal output.

    Wraps a ``rich`` progress bar and spinner so the optimization event
    loop can report progress without dealing with rendering details.

    Args:
        show_progress_bar:
            Enables progress bar.
        n_trials:
            The number of trials to run in total.
        timeout:
            Stops study after the given number of second(s).
    """

    def __init__(
        self, show_progress_bar: bool, n_trials: int, timeout: Optional[float] = None
    ) -> None:
        self._timeout = timeout
        # Columns rendered left-to-right; transient=True clears the bar on stop.
        columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(complete_style=Style(color="light_coral")),
            TaskProgressColumn(),
            TimeElapsedColumn(),
        )
        self._progbar = Progress(*columns, transient=True)
        # The task is always registered so update/close calls are safe even
        # when the bar is never started (show_progress_bar=False).
        self._task = self._progbar.add_task("[blue]Running trials...[/blue]", total=n_trials)
        if show_progress_bar:
            self._progbar.start()

    def update_progress_bar(self) -> None:
        """Advance progress bar by one trial."""
        self._progbar.advance(self._task)

    def close_progress_bar(self) -> None:
        """Closes progress bar."""
        self._progbar.stop()

    def spin_while_trials_interrupted(self) -> Status:
        """Renders spinner animation while trials are being interrupted."""
        # Stop the bar first so the spinner does not render on top of it.
        self._progbar.stop()
        spinner_style = Style(color="blue")
        return self._progbar.console.status(
            "[blue]Interrupting running trials...[/blue]",
            spinner_style=spinner_style,  # type: ignore
        )
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ classifiers = [
dependencies = [
"optuna",
"dask[distributed]",
"colorlog",
"rich",
"typing-extensions",
]
dynamic = ["version", "readme"]
Expand Down
9 changes: 6 additions & 3 deletions tests/test_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from optuna_distributed.eventloop import EventLoop
from optuna_distributed.managers import LocalOptimizationManager
from optuna_distributed.terminal import Terminal
from optuna_distributed.trial import DistributedTrial


Expand All @@ -17,7 +18,7 @@ def _objective(trial: DistributedTrial) -> float:
manager = LocalOptimizationManager(n_trials, n_jobs=1)
event_loop = EventLoop(study, manager, objective=_objective)
with pytest.raises(ValueError):
event_loop.run(n_trials)
event_loop.run(terminal=Terminal(show_progress_bar=False, n_trials=n_trials))


def test_catches_on_trial_exception() -> None:
Expand All @@ -28,7 +29,9 @@ def _objective(trial: DistributedTrial) -> float:
study = optuna.create_study()
manager = LocalOptimizationManager(n_trials, n_jobs=1)
event_loop = EventLoop(study, manager, objective=_objective)
event_loop.run(n_trials, catch=(ValueError,))
event_loop.run(
terminal=Terminal(show_progress_bar=False, n_trials=n_trials), catch=(ValueError,)
)


def test_stops_optimization() -> None:
Expand All @@ -43,6 +46,6 @@ def _objective(trial: DistributedTrial) -> float:
manager = LocalOptimizationManager(n_trials, n_jobs=1)
event_loop = EventLoop(study, manager, objective=_objective)
started_at = time.time()
event_loop.run(n_trials, timeout=1.0)
event_loop.run(terminal=Terminal(show_progress_bar=False, n_trials=n_trials), timeout=1.0)
interrupted_execution_time = time.time() - started_at
assert interrupted_execution_time < uninterrupted_execution_time