Skip to content

Commit

Permalink
[air output] Print single trial config + results as table (#34788)
Browse files Browse the repository at this point in the history
This PR makes a number of improvements to tackle issues uncovered in dogfooding:

1. Instead of just a print, we render a table for configs and results at the start of training (closes #34784)
2. We round float results to at most 5 decimal places (closes #34785)
3. We track the last printed result and only print the result at the end of training if it hasn't been printed before (closes #34786)
4. We divide the results into "automatic" results and trainer-specific results (closes #34787)

Signed-off-by: Kai Fricke <kai@anyscale.com>
  • Loading branch information
krfricke authored May 2, 2023
1 parent c142efa commit cac19c9
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 37 deletions.
3 changes: 2 additions & 1 deletion python/ray/air/execution/resources/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def to_placement_group(self):

def __eq__(self, other: "ResourceRequest"):
return (
self._bound == other._bound
isinstance(other, ResourceRequest)
and self._bound == other._bound
and self.head_bundle_is_empty == other.head_bundle_is_empty
)

Expand Down
166 changes: 130 additions & 36 deletions python/ray/tune/experimental/output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Optional, Tuple, Any, TYPE_CHECKING
from typing import List, Dict, Optional, Tuple, Any, TYPE_CHECKING, Collection

import contextlib
import collections
Expand All @@ -11,7 +11,6 @@
import numpy as np
import os
import pandas as pd
from ray._private.thirdparty.tabulate.tabulate import tabulate
import textwrap
import time

Expand All @@ -24,9 +23,14 @@

import ray
from ray._private.dict import unflattened_lookup
from ray._private.thirdparty.tabulate.tabulate import (
tabulate,
TableFormat,
Line,
DataRow,
)
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
from ray.tune.callback import Callback
from ray.tune.logger import pretty_print
from ray.tune.result import (
AUTO_RESULT_KEYS,
EPISODE_REWARD_MEAN,
Expand Down Expand Up @@ -356,6 +360,91 @@ def _best_trial_str(
)


def _render_table_item(key: str, item: Any, prefix: str = ""):
key = prefix + key
if isinstance(item, float):
# tabulate does not work well with mixed-type columns, so we format
# numbers ourselves.
yield key, f"{item:.5f}".rstrip("0")
elif isinstance(item, list):
yield key, None
for sv in item:
yield from _render_table_item("", sv, prefix=prefix + "-")
elif isinstance(item, Dict):
yield key, None
for sk, sv in item.items():
yield from _render_table_item(str(sk), sv, prefix=prefix + "/")
else:
yield key, item


def _get_dict_as_table_data(
data: Dict,
exclude: Optional[Collection] = None,
upper_keys: Optional[Collection] = None,
):
exclude = exclude or set()
upper_keys = upper_keys or set()

upper = []
lower = []

for key, value in sorted(data.items()):
if key in exclude:
continue

for k, v in _render_table_item(str(key), value):
if key in upper_keys:
upper.append([k, v])
else:
lower.append([k, v])

if not upper:
return lower
elif not lower:
return upper
else:
return upper + lower


# Copied/adjusted from tabulate
# Custom table format drawn with rounded box-drawing corners. It is passed as
# ``tablefmt`` to ``tabulate`` by ``_print_dict_as_table``.
# NOTE(review): the Line/DataRow argument order is presumably
# (begin, fill, separator, end) / (begin, separator, end) as in tabulate --
# confirm against the vendored tabulate module.
AIR_TABULATE_TABLEFMT = TableFormat(
# Top border of the table.
lineabove=Line("╭", "─", "─", "╮"),
# Rule between the header row and the data rows.
linebelowheader=Line("├", "─", "─", "┤"),
# No rule configured between individual data rows.
linebetweenrows=None,
# Bottom border of the table.
linebelow=Line("╰", "─", "─", "╯"),
# Header and data rows share the same borders; the column separator is a
# plain space rather than a vertical bar.
headerrow=DataRow("│", " ", "│"),
datarow=DataRow("│", " ", "│"),
padding=1,
with_header_hide=None,
)


def _print_dict_as_table(
    data: Dict,
    header: Optional[str] = None,
    exclude: Optional[Collection] = None,
    division: Optional[Collection] = None,
):
    """Print ``data`` as a two-column table, or nothing if no rows remain.

    Args:
        data: Mapping to print.
        header: Optional title for the first column; the second column's
            title is left empty. Without it the table has no header row.
        exclude: Top-level keys to leave out of the table.
        division: Top-level keys whose rows are listed before all others
            (forwarded as ``upper_keys``).
    """
    rows = _get_dict_as_table_data(data=data, exclude=exclude, upper_keys=division)
    if not rows:
        # Nothing left after filtering, so do not print an empty frame.
        return

    if header:
        column_titles = [header, ""]
    else:
        column_titles = []

    rendered = tabulate(
        rows,
        headers=column_titles,
        colalign=("left", "right"),
        tablefmt=AIR_TABULATE_TABLEFMT,
    )
    print(rendered)


class ProgressReporter:
"""Periodically prints out status update."""

Expand Down Expand Up @@ -594,6 +683,7 @@ def _print_heartbeat(self, trials, *args):

# These keys are blacklisted for printing out training/tuning intermediate/final result!
BLACKLISTED_KEYS = {
"config",
"date",
"done",
"hostname",
Expand Down Expand Up @@ -650,12 +740,28 @@ class AirResultProgressCallback(Callback):
def __init__(self, verbosity):
self._verbosity = verbosity
self._start_time = time.time()
self._trial_last_printed_results = {}

def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
"""Only print result if a different result has been reported, or force=True"""
result = result or trial.last_result

def _print_result(self, trial, result=None):
print(pretty_print(result or trial.last_result, BLACKLISTED_KEYS))
last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
this_iter = result.get(TRAINING_ITERATION, 0)

if this_iter != last_result_iter or force:
_print_dict_as_table(
result,
header=f"{self._addressing_tmpl.format(trial)} result",
exclude=BLACKLISTED_KEYS,
division=AUTO_RESULT_KEYS,
)
self._trial_last_printed_results[trial.trial_id] = this_iter

def _print_config(self, trial):
print(pretty_print(trial.config))
_print_dict_as_table(
trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
)

def on_trial_result(
self,
Expand All @@ -669,13 +775,9 @@ def on_trial_result(
return
curr_time, running_time = _get_time_str(self._start_time, time.time())
print(
" ".join(
[
self._addressing_tmpl.format(trial),
f"finished iter {result[TRAINING_ITERATION]} "
f"at {curr_time} (running for {running_time})",
]
)
f"{self._addressing_tmpl.format(trial)} "
f"finished iteration {result[TRAINING_ITERATION]} "
f"at {curr_time} (running for {running_time})."
)
self._print_result(trial, result)

Expand All @@ -689,13 +791,9 @@ def on_trial_complete(
if trial.last_result and TRAINING_ITERATION in trial.last_result:
finished_iter = trial.last_result[TRAINING_ITERATION]
print(
" ".join(
[
self._addressing_tmpl.format(trial),
f"({finished_iter} iters) "
f"finished at {curr_time} (running for {running_time})",
]
)
f"{self._addressing_tmpl.format(trial)} "
f"completed training after {finished_iter} iterations "
f"at {curr_time} (running for {running_time})."
)
self._print_result(trial)

Expand All @@ -714,30 +812,26 @@ def on_checkpoint(
if trial.last_result and TRAINING_ITERATION in trial.last_result:
saved_iter = trial.last_result[TRAINING_ITERATION]
print(
" ".join(
[
self._addressing_tmpl.format(trial),
f"saved checkpoint for iter {saved_iter}"
f" at {checkpoint.dir_or_data}",
]
)
f"{self._addressing_tmpl.format(trial)} "
f"saved a checkpoint for iteration {saved_iter} "
f"at: {checkpoint.dir_or_data}"
)
print()

def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
if self._verbosity < self._start_end_verbosity:
return
has_config = bool(trial.config)
print(
" ".join(
[
self._addressing_tmpl.format(trial),
"started with configuration:" if has_config else "started.",
]
)
)

if has_config:
print(
f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
)
self._print_config(trial)
else:
print(
f"{self._addressing_tmpl.format(trial)} "
f"started without custom configuration."
)


class TuneResultProgressCallback(AirResultProgressCallback):
Expand Down
56 changes: 56 additions & 0 deletions python/ray/tune/tests/output/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_current_best_trial,
_best_trial_str,
_get_trial_table_data,
_get_dict_as_table_data,
)
from ray.tune.experiment.trial import Trial

Expand Down Expand Up @@ -164,5 +165,60 @@ def test_get_trial_table_data_more_than_20():
assert table_data[2].more_info == "... and 5 more PENDING ..."


def test_result_table_no_divison():
    """Without upper keys, rows are key-sorted with nested dicts flattened."""
    result = {
        "b": 6,
        "a": 8,
        "x": 19.123123123,
        "c": 5,
        "ignore": 9,
        "y": 20,
        "z": {"m": 4, "n": {"o": "p"}},
    }
    expected_rows = [
        ["a", 8],
        ["b", 6],
        ["c", 5],
        ["x", "19.12312"],
        ["y", 20],
        ["z", None],
        ["/m", 4],
        ["/n", None],
        ["//o", "p"],
    ]

    rows = _get_dict_as_table_data(result, exclude={"ignore"})

    assert rows == expected_rows


def test_result_table_divison():
    """Rows of the upper keys come first, then the remaining sorted rows."""
    result = {
        "b": 6,
        "a": 8,
        "x": 19.123123123,
        "c": 5,
        "ignore": 9,
        "y": 20,
        "z": {"m": 4, "n": {"o": "p"}},
    }
    expected_upper = [
        ["x", "19.12312"],
        ["y", 20],
        ["z", None],
        ["/m", 4],
        ["/n", None],
        ["//o", "p"],
    ]
    expected_lower = [
        ["a", 8],
        ["b", 6],
        ["c", 5],
    ]

    rows = _get_dict_as_table_data(
        result,
        exclude={"ignore"},
        upper_keys={"x", "y", "z"},
    )

    assert rows == expected_upper + expected_lower


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))

0 comments on commit cac19c9

Please sign in to comment.