Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hparams: allow setting trial ID #2442

Merged
merged 2 commits into from
Jul 19, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions tensorboard/plugins/hparams/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Callback(tf.keras.callbacks.Callback):
NOTE: This callback only works in TensorFlow eager mode.
"""

def __init__(self, writer, hparams):
def __init__(self, writer, hparams, trial_id=None):
"""Create a callback for logging hyperparameters to TensorBoard.

As with the standard `tf.keras.callbacks.TensorBoard` class, each
Expand All @@ -51,6 +51,9 @@ def __init__(self, writer, hparams):
in an experiment, or the `HParam` objects themselves. Values
should be Python `bool`, `int`, `float`, or `string` values,
depending on the type of the hyperparameter.
trial_id: An optional `str` ID for the set of hyperparameter
values used in this trial. Defaults to a hash of the
hyperparameters.

Raises:
ValueError: If two entries in `hparams` share the same
Expand All @@ -60,7 +63,8 @@ def __init__(self, writer, hparams):
# timestamp is correct. But create a "dry-run" first to fail fast in
# case the `hparams` are invalid.
self._hparams = dict(hparams)
summary_v2.hparams_pb(self._hparams)
self._trial_id = trial_id
summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
if writer is None:
raise TypeError("writer must be a `SummaryWriter` or `str`, not None")
elif isinstance(writer, str):
Expand All @@ -82,7 +86,7 @@ def _get_writer(self):
def on_train_begin(self, logs=None):
del logs # unused
with self._get_writer().as_default():
summary_v2.hparams(self._hparams)
summary_v2.hparams(self._hparams, trial_id=self._trial_id)

def on_train_end(self, logs=None):
del logs # unused
Expand Down
12 changes: 8 additions & 4 deletions tensorboard/plugins/hparams/keras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _initialize_model(self, writer):
tf.keras.layers.Dense(1, activation="sigmoid"),
])
self.model.compile(loss="mse", optimizer=self.hparams["optimizer"])
self.callback = keras.Callback(writer, self.hparams)
self.trial_id = "my_trial"
self.callback = keras.Callback(writer, self.hparams, trial_id=self.trial_id)

def test_eager(self):
def mock_time():
Expand Down Expand Up @@ -99,13 +100,11 @@ def mock_time():
start_pb.start_time_secs = 1234.5
end_pb.end_time_secs = 6789.0

start_pb.group_name = "do_not_care"

expected_start_pb = plugin_data_pb2.SessionStartInfo()
text_format.Merge(
"""
start_time_secs: 1234.5
group_name: "do_not_care"
group_name: "my_trial"
hparams {
key: "optimizer"
value {
Expand Down Expand Up @@ -186,6 +185,11 @@ def test_duplicate_hparam_names_from_two_objects(self):
self, ValueError, "multiple values specified for hparam 'foo'"):
keras.Callback(self.get_temp_dir(), hparams)

def test_invalid_trial_id(self):
  """Constructing a `Callback` with a non-`str` `trial_id` raises `TypeError`."""
  expected_message = "`trial_id` should be a `str`, but got: 12"
  with six.assertRaisesRegex(self, TypeError, expected_message):
    keras.Callback(self.get_temp_dir(), {}, trial_id=12)


if __name__ == "__main__":
tf.test.main()
17 changes: 13 additions & 4 deletions tensorboard/plugins/hparams/summary_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tensorboard.plugins.hparams import plugin_data_pb2


def hparams(hparams, start_time_secs=None):
def hparams(hparams, trial_id=None, start_time_secs=None):
# NOTE: Keep docs in sync with `hparams_pb` below.
"""Write hyperparameter values for a single trial.

Expand All @@ -46,6 +46,8 @@ def hparams(hparams, start_time_secs=None):
experiment, or the `HParam` objects themselves. Values should be
Python `bool`, `int`, `float`, or `string` values, depending on
the type of the hyperparameter.
trial_id: An optional `str` ID for the set of hyperparameter values
used in this trial. Defaults to a hash of the hyperparameters.
start_time_secs: The time that this trial started training, as
seconds since epoch. Defaults to the current time.

Expand All @@ -55,12 +57,13 @@ def hparams(hparams, start_time_secs=None):
"""
pb = hparams_pb(
hparams=hparams,
trial_id=trial_id,
start_time_secs=start_time_secs,
)
return _write_summary("hparams", pb)


def hparams_pb(hparams, start_time_secs=None):
def hparams_pb(hparams, trial_id=None, start_time_secs=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably no action required: if we shipped this API as part of 1.14, and if a user doesn't use the kwarg, this change will break them, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct; you’re supposed to use the kwarg.

(But we wouldn’t ship this as part of 1.14 anyway, because it’s a new
feature.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can change these to be keyword-only arguments once we drop
Python 2 support, and likewise with TBContext and friends.

# NOTE: Keep docs in sync with `hparams` above.
"""Create a summary encoding hyperparameter values for a single trial.

Expand All @@ -70,6 +73,8 @@ def hparams_pb(hparams, start_time_secs=None):
experiment, or the `HParam` objects themselves. Values should be
Python `bool`, `int`, `float`, or `string` values, depending on
the type of the hyperparameter.
trial_id: An optional `str` ID for the set of hyperparameter values
used in this trial. Defaults to a hash of the hyperparameters.
start_time_secs: The time that this trial started training, as
seconds since epoch. Defaults to the current time.

Expand All @@ -79,7 +84,7 @@ def hparams_pb(hparams, start_time_secs=None):
if start_time_secs is None:
start_time_secs = time.time()
hparams = _normalize_hparams(hparams)
group_name = _derive_session_group_name(hparams)
group_name = _derive_session_group_name(trial_id, hparams)

session_start_info = plugin_data_pb2.SessionStartInfo(
group_name=group_name,
Expand Down Expand Up @@ -199,7 +204,11 @@ def _normalize_hparams(hparams):
return result


def _derive_session_group_name(hparams):
def _derive_session_group_name(trial_id, hparams):
if trial_id is not None:
if not isinstance(trial_id, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes; that’s probably better. Might as well fully support Python 2.

raise TypeError("`trial_id` should be a `str`, but got: %r" % (trial_id,))
return trial_id
# Use `json.dumps` rather than `str` to ensure invariance under string
# type (incl. across Python versions) and dict iteration order.
jparams = json.dumps(hparams, sort_keys=True, separators=(",", ":"))
Expand Down
38 changes: 30 additions & 8 deletions tensorboard/plugins/hparams/summary_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def setUp(self):
"dropout": 0.3,
}
self.start_time_secs = 123.45
self.group_name = "big_sha"
self.trial_id = "psl27"

self.expected_session_start_pb = plugin_data_pb2.SessionStartInfo()
text_format.Merge(
Expand All @@ -93,13 +93,13 @@ def setUp(self):
hparams { key: "who_knows_what" value { string_value: "???" } }
hparams { key: "magic" value { bool_value: true } }
hparams { key: "dropout" value { number_value: 0.3 } }
group_name: "big_sha" # we'll ignore this field when asserting equality
""",
self.expected_session_start_pb,
)
self.expected_session_start_pb.group_name = self.trial_id
self.expected_session_start_pb.start_time_secs = self.start_time_secs

def _check_summary(self, summary_pb):
def _check_summary(self, summary_pb, check_group_name=False):
"""Test that a summary contains exactly the expected hparams PB."""
values = summary_pb.value
self.assertEqual(len(values), 1, values)
Expand All @@ -110,18 +110,27 @@ def _check_summary(self, summary_pb):
)
plugin_content = actual_value.metadata.plugin_data.content
info_pb = metadata.parse_session_start_info_plugin_data(plugin_content)
# Ignore the `group_name` field; its properties are checked separately.
info_pb.group_name = self.expected_session_start_pb.group_name
# Usually ignore the `group_name` field; its properties are checked
# separately.
if not check_group_name:
info_pb.group_name = self.expected_session_start_pb.group_name
self.assertEqual(info_pb, self.expected_session_start_pb)

def _check_logdir(self, logdir):
def _check_logdir(self, logdir, check_group_name=False):
"""Test that the hparams summary was written to `logdir`."""
self._check_summary(_get_unique_summary(self, logdir))
self._check_summary(
_get_unique_summary(self, logdir),
check_group_name=check_group_name,
)

@requires_tf
def test_eager(self):
with tf.compat.v2.summary.create_file_writer(self.logdir).as_default():
result = hp.hparams(self.hparams, start_time_secs=self.start_time_secs)
result = hp.hparams(
self.hparams,
trial_id=self.trial_id,
start_time_secs=self.start_time_secs,
)
self.assertTrue(result)
self._check_logdir(self.logdir)

Expand Down Expand Up @@ -152,6 +161,19 @@ def test_pb_is_tensorboard_copy_of_proto(self):
if tf is not None:
self.assertNotIsInstance(result, tf.compat.v1.Summary)

def test_pb_explicit_trial_id(self):
  """An explicitly supplied `trial_id` becomes the summary's group name."""
  summary_pb = hp.hparams_pb(
      self.hparams,
      trial_id=self.trial_id,
      start_time_secs=self.start_time_secs,
  )
  # Unlike most checks in this file, compare `group_name` too: the expected
  # proto carries `self.trial_id` in that field.
  self._check_summary(summary_pb, check_group_name=True)

def test_pb_invalid_trial_id(self):
  """`hparams_pb` raises `TypeError` when `trial_id` is not a `str`."""
  expected_message = "`trial_id` should be a `str`, but got: 12"
  with six.assertRaisesRegex(self, TypeError, expected_message):
    hp.hparams_pb(self.hparams, trial_id=12)

def assert_hparams_summaries_equal(self, summary_1, summary_2):
def canonical(summary):
"""Return a canonical form for `summary`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
items="[[sessionGroups]]">
<vaadin-grid-column flex-grow="0" width="10em" resizable>
<template class="header">
<div class="table-header table-cell">Session Group Name.</div>
<div class="table-header table-cell">Trial ID</div>
</template>
<template>
<div class="table-cell">[[item.name]]</div>
Expand Down