Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hparams: allow setting trial ID #2442

Merged
merged 2 commits into from
Jul 19, 2019
Merged

hparams: allow setting trial ID #2442

merged 2 commits into from
Jul 19, 2019

Conversation

wchargin
Copy link
Contributor

@wchargin wchargin commented Jul 18, 2019

Summary:
Resolves #2440. See #1998 for discussion.

Test Plan:
The hparams demo still does not specify trial IDs (intentionally, as
this is the usual path). But apply the following patch—

diff --git a/tensorboard/plugins/hparams/hparams_demo.py b/tensorboard/plugins/hparams/hparams_demo.py
index ac4e762b..38b2b122 100644
--- a/tensorboard/plugins/hparams/hparams_demo.py
+++ b/tensorboard/plugins/hparams/hparams_demo.py
@@ -160,7 +160,7 @@ def model_fn(hparams, seed):
   return model
 
 
-def run(data, base_logdir, session_id, hparams):
+def run(data, base_logdir, session_id, trial_id, hparams):
   """Run a training/validation session.
 
   Flags must have been parsed for this function to behave.
@@ -179,7 +179,7 @@ def run(data, base_logdir, session_id, hparams):
       update_freq=flags.FLAGS.summary_freq,
       profile_batch=0,  # workaround for issue #2084
   )
-  hparams_callback = hp.KerasCallback(logdir, hparams)
+  hparams_callback = hp.KerasCallback(logdir, hparams, trial_id=trial_id)
   ((x_train, y_train), (x_test, y_test)) = data
   result = model.fit(
       x=x_train,
@@ -235,6 +235,7 @@ def run_all(logdir, verbose=False):
           data=data,
           base_logdir=logdir,
           session_id=session_id,
+          trial_id="trial-%d" % group_index,
           hparams=hparams,
       )
 

—and then run //tensorboard/plugins/hparams:hparams_demo, and observe
that the HParams dashboard renders a “Trial ID” column with the
specified IDs:

Screenshot of new version of HParams dashboard

wchargin-branch: hparams-trial-id

Summary:
Resolves #2440. See #1998 for discussion.

Test Plan:
The hparams demo still does not specify trial IDs (intentionally, as
this is the usual path). But apply the following patch—

```diff
diff --git a/tensorboard/plugins/hparams/hparams_demo.py b/tensorboard/plugins/hparams/hparams_demo.py
index ac4e762b..d0279f27 100644
--- a/tensorboard/plugins/hparams/hparams_demo.py
+++ b/tensorboard/plugins/hparams/hparams_demo.py
@@ -63,7 +63,7 @@ flags.DEFINE_integer(
 )
 flags.DEFINE_integer(
     "num_epochs",
-    5,
+    1,
     "Number of epochs per trial.",
 )
 
@@ -160,7 +160,7 @@ def model_fn(hparams, seed):
   return model
 
 
-def run(data, base_logdir, session_id, hparams):
+def run(data, base_logdir, session_id, trial_id, hparams):
   """Run a training/validation session.
 
   Flags must have been parsed for this function to behave.
@@ -179,7 +179,7 @@ def run(data, base_logdir, session_id, hparams):
       update_freq=flags.FLAGS.summary_freq,
       profile_batch=0,  # workaround for issue #2084
   )
-  hparams_callback = hp.KerasCallback(logdir, hparams)
+  hparams_callback = hp.KerasCallback(logdir, hparams, trial_id=trial_id)
   ((x_train, y_train), (x_test, y_test)) = data
   result = model.fit(
       x=x_train,
@@ -235,6 +235,7 @@ def run_all(logdir, verbose=False):
           data=data,
           base_logdir=logdir,
           session_id=session_id,
+          trial_id="trial-%d" % group_index,
           hparams=hparams,
       )
 
```

—and then run `//tensorboard/plugins/hparams:hparams_demo`, and observe
that the HParams dashboard renders a “Trial ID” column with the
specified IDs:

![Screenshot of new version of HParams dashboard]

[1]: https://user-images.githubusercontent.com/4317806/61491024-1fb01280-a963-11e9-8a47-35e0a01f3691.png

wchargin-branch: hparams-trial-id
def _derive_session_group_name(hparams):
def _derive_session_group_name(trial_id, hparams):
if trial_id is not None:
if not isinstance(trial_id, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes; that’s probably better. Might as well fully support Python 2.

start_time_secs=start_time_secs,
)
return _write_summary("hparams", pb)


def hparams_pb(hparams, start_time_secs=None):
def hparams_pb(hparams, trial_id=None, start_time_secs=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably no action required: if we shipped this API as part of 1.14, and if a user don't use the kwarg, this change will break them, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct; you’re supposed to use use the kwarg.

(But we wouldn’t ship this as part of 1.14 anyway, because it’s a new
feature.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can change these to be keyword-only arguments once we drop
Python 2 support, and likewise with TBContext and friends.

Copy link
Contributor Author

@wchargin wchargin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will change to string_types.

start_time_secs=start_time_secs,
)
return _write_summary("hparams", pb)


def hparams_pb(hparams, start_time_secs=None):
def hparams_pb(hparams, trial_id=None, start_time_secs=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct; you’re supposed to use use the kwarg.

(But we wouldn’t ship this as part of 1.14 anyway, because it’s a new
feature.)

def _derive_session_group_name(hparams):
def _derive_session_group_name(trial_id, hparams):
if trial_id is not None:
if not isinstance(trial_id, str):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes; that’s probably better. Might as well fully support Python 2.

wchargin-source: 550ad71d9d292f8caab432487533630c1362313b
wchargin-branch: hparams-trial-id
@wchargin wchargin merged commit 192ab46 into master Jul 19, 2019
@wchargin wchargin deleted the wchargin-hparams-trial-id branch July 19, 2019 17:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

HParams Session Group Name.
3 participants