Skip to content

Commit

Permalink
hparams: thread experiment ID through data access (#3418)
Browse files Browse the repository at this point in the history
Summary:
All data access in the hparams plugin is now performed via methods that
require an experiment ID. The ID is currently unused except for an
assertion. In a subsequent change, these methods will switch to using
a data provider instead of the multiplexer, at which point the
experiment ID will be required.

Test Plan:
Unit tests suffice for the internals. To check the wiring at the plugin
level, note that all three views of the hparams plugin, including
embedded scalar charts and the download links, still render properly.

wchargin-branch: hparams-thread-eid
  • Loading branch information
wchargin authored and bileschi committed Apr 15, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent a73772f commit 572f9be
Showing 7 changed files with 83 additions and 53 deletions.
81 changes: 47 additions & 34 deletions tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
@@ -53,11 +53,9 @@ def __init__(self, tb_context, max_domain_discrete_len=10):
Typically, only tests should specify a value for this parameter.
"""
self._tb_context = tb_context
self._experiment_from_tag = None
self._experiment_from_tag_lock = threading.Lock()
self._max_domain_discrete_len = max_domain_discrete_len

def experiment(self):
def experiment(self, experiment_id):
"""Returns the experiment protobuffer defining the experiment.
This method first attempts to find a metadata.EXPERIMENT_TAG tag and
@@ -72,9 +70,9 @@ def experiment(self):
protobuffer can be built (possibly, because the event data has not been
completely loaded yet), returns None.
"""
experiment = self._find_experiment_tag()
experiment = self._find_experiment_tag(experiment_id)
if experiment is None:
return self._compute_experiment_from_runs()
return self._compute_experiment_from_runs(experiment_id)
return experiment

@property
@@ -89,72 +87,87 @@ def multiplexer(self):
def tb_context(self):
return self._tb_context

def hparams_metadata(self):
def hparams_metadata(self, experiment_id):
"""Reads summary metadata for all hparams time series.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
)

def scalars_metadata(self):
def scalars_metadata(self, experiment_id):
"""Reads summary metadata for all scalar time series.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.PluginRunToTagToContent(
scalar_metadata.PLUGIN_NAME
)

def read_scalars(self, run, tag):
def read_scalars(self, experiment_id, run, tag):
"""Reads values for a given scalar time series.
Args:
experiment_id: String.
run: String.
tag: String.
Returns:
A list of `plugin_event_accumulator.TensorEvent` values.
"""
assert isinstance(experiment_id, str), (
experiment_id,
type(experiment_id),
)
return self._deprecated_multiplexer.Tensors(run, tag)

def _find_experiment_tag(self):
def _find_experiment_tag(self, experiment_id):
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
tag.
Caches the experiment if it was found.
Returns:
The experiment or None if no such experiment is found.
"""
with self._experiment_from_tag_lock:
if self._experiment_from_tag is None:
mapping = self.hparams_metadata()
for tag_to_content in mapping.values():
if metadata.EXPERIMENT_TAG in tag_to_content:
self._experiment_from_tag = metadata.parse_experiment_plugin_data(
tag_to_content[metadata.EXPERIMENT_TAG]
)
break
return self._experiment_from_tag

def _compute_experiment_from_runs(self):
mapping = self.hparams_metadata(experiment_id)
for tag_to_content in mapping.values():
if metadata.EXPERIMENT_TAG in tag_to_content:
experiment = metadata.parse_experiment_plugin_data(
tag_to_content[metadata.EXPERIMENT_TAG]
)
return experiment
return None

def _compute_experiment_from_runs(self, experiment_id):
"""Computes a minimal Experiment protocol buffer by scanning the
runs."""
hparam_infos = self._compute_hparam_infos()
hparam_infos = self._compute_hparam_infos(experiment_id)
if not hparam_infos:
return None
metric_infos = self._compute_metric_infos()
metric_infos = self._compute_metric_infos(experiment_id)
return api_pb2.Experiment(
hparam_infos=hparam_infos, metric_infos=metric_infos
)

def _compute_hparam_infos(self):
def _compute_hparam_infos(self, experiment_id):
"""Computes a list of api_pb2.HParamInfo from the current run, tag
info.
@@ -167,7 +180,7 @@ def _compute_hparam_infos(self):
Returns:
A list of api_pb2.HParamInfo messages.
"""
run_to_tag_to_content = self.hparams_metadata()
run_to_tag_to_content = self.hparams_metadata(experiment_id)
# Construct a dict mapping an hparam name to its list of values.
hparams = collections.defaultdict(list)
for tag_to_content in run_to_tag_to_content.values():
@@ -236,13 +249,13 @@ def _compute_hparam_info_from_values(self, name, values):

return result

def _compute_metric_infos(self):
def _compute_metric_infos(self, experiment_id):
return (
api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
for tag, group in self._compute_metric_names()
for tag, group in self._compute_metric_names(experiment_id)
)

def _compute_metric_names(self):
def _compute_metric_names(self, experiment_id):
"""Computes the list of metric names from all the scalar (run, tag)
pairs.
@@ -268,9 +281,9 @@ def _compute_metric_names(self):
A python list containing pairs. Each pair is a (tag, group) pair
representing a metric name used in some session.
"""
session_runs = self._build_session_runs_set()
session_runs = self._build_session_runs_set(experiment_id)
metric_names_set = set()
run_to_tag_to_content = self.scalars_metadata()
run_to_tag_to_content = self.scalars_metadata(experiment_id)
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
session = _find_longest_parent_path(session_runs, run)
if not session:
@@ -288,9 +301,9 @@ def _compute_metric_names(self):
metric_names_list.sort()
return metric_names_list

def _build_session_runs_set(self):
def _build_session_runs_set(self, experiment_id):
result = set()
run_to_tag_to_content = self.hparams_metadata()
run_to_tag_to_content = self.hparams_metadata(experiment_id)
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
if metadata.SESSION_START_INFO_TAG in tag_to_content:
result.add(run)
8 changes: 4 additions & 4 deletions tensorboard/plugins/hparams/backend_context_test.py
Original file line number Diff line number Diff line change
@@ -113,7 +113,7 @@ def test_experiment_with_experiment_tag(self):
}
}
ctxt = backend_context.Context(self._mock_tb_context)
self.assertProtoEquals(experiment, ctxt.experiment())
self.assertProtoEquals(experiment, ctxt.experiment(experiment_id="123"))

def test_experiment_without_experiment_tag(self):
self.session_1_start_info_ = """
@@ -168,7 +168,7 @@ def test_experiment_without_experiment_tag(self):
}
"""
ctxt = backend_context.Context(self._mock_tb_context)
actual_exp = ctxt.experiment()
actual_exp = ctxt.experiment(experiment_id="123")
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

@@ -230,7 +230,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self):
}
"""
ctxt = backend_context.Context(self._mock_tb_context)
actual_exp = ctxt.experiment()
actual_exp = ctxt.experiment(experiment_id="123")
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

@@ -285,7 +285,7 @@ def test_experiment_without_experiment_tag_many_distinct_values(self):
ctxt = backend_context.Context(
self._mock_tb_context, max_domain_discrete_len=1
)
actual_exp = ctxt.experiment()
actual_exp = ctxt.experiment(experiment_id="123")
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

6 changes: 4 additions & 2 deletions tensorboard/plugins/hparams/get_experiment.py
Original file line number Diff line number Diff line change
@@ -25,21 +25,23 @@
class Handler(object):
"""Handles a GetExperiment request."""

def __init__(self, context):
def __init__(self, context, experiment_id):
"""Constructor.
Args:
context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
"""
self._context = context
self._experiment_id = experiment_id

def run(self):
"""Handles the request specified on construction.
Returns:
An Experiment object.
"""
experiment = self._context.experiment()
experiment = self._context.experiment(self._experiment_id)
if experiment is None:
raise error.HParamsError(
"Can't find an HParams-plugin experiment data in"
17 changes: 11 additions & 6 deletions tensorboard/plugins/hparams/hparams_plugin.py
Original file line number Diff line number Diff line change
@@ -82,6 +82,7 @@ def frontend_metadata(self):
# ---- /download_data- -------------------------------------------------------
@wrappers.Request.application
def download_data_route(self, request):
experiment_id = plugin_util.experiment_id(request.environ)
try:
response_format = request.args.get("format")
columns_visibility = json.loads(
@@ -91,9 +92,11 @@ def download_data_route(self, request):
request, api_pb2.ListSessionGroupsRequest
)
session_groups = list_session_groups.Handler(
self._context, request_proto
self._context, experiment_id, request_proto
).run()
experiment = get_experiment.Handler(
self._context, experiment_id
).run()
experiment = get_experiment.Handler(self._context).run()
body, mime_type = download_data.Handler(
self._context,
experiment,
@@ -109,6 +112,7 @@ def download_data_route(self, request):
# ---- /experiment -----------------------------------------------------------
@wrappers.Request.application
def get_experiment_route(self, request):
experiment_id = plugin_util.experiment_id(request.environ)
try:
# This backend currently ignores the request parameters, but (for a POST)
# we must advance the input stream to skip them -- otherwise the next HTTP
@@ -117,7 +121,7 @@ def get_experiment_route(self, request):
return http_util.Respond(
request,
json_format.MessageToJson(
get_experiment.Handler(self._context).run(),
get_experiment.Handler(self._context, experiment_id).run(),
including_default_value_fields=True,
),
"application/json",
@@ -129,6 +133,7 @@ def get_experiment_route(self, request):
# ---- /session_groups -------------------------------------------------------
@wrappers.Request.application
def list_session_groups_route(self, request):
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListSessionGroupsRequest
@@ -137,7 +142,7 @@ def list_session_groups_route(self, request):
request,
json_format.MessageToJson(
list_session_groups.Handler(
self._context, request_proto
self._context, experiment_id, request_proto
).run(),
including_default_value_fields=True,
),
@@ -150,7 +155,7 @@ def list_session_groups_route(self, request):
# ---- /metric_evals ---------------------------------------------------------
@wrappers.Request.application
def list_metric_evals_route(self, request):
experiment = plugin_util.experiment_id(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListMetricEvalsRequest
@@ -162,7 +167,7 @@ def list_metric_evals_route(self, request):
request,
json.dumps(
list_metric_evals.Handler(
request_proto, scalars_plugin, experiment
request_proto, scalars_plugin, experiment_id
).run()
),
"application/json",
15 changes: 11 additions & 4 deletions tensorboard/plugins/hparams/list_session_groups.py
Original file line number Diff line number Diff line change
@@ -35,20 +35,22 @@
class Handler(object):
"""Handles a ListSessionGroups request."""

def __init__(self, context, request):
def __init__(self, context, experiment_id, request):
"""Constructor.
Args:
context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
request: A ListSessionGroupsRequest protobuf.
"""
self._context = context
self._experiment_id = experiment_id
self._request = request
self._extractors = _create_extractors(request.col_params)
self._filters = _create_filters(request.col_params, self._extractors)
# Since an context.experiment() call may search through all the runs, we
# cache it here.
self._experiment = context.experiment()
self._experiment = context.experiment(experiment_id)

def run(self):
"""Handles the request specified on construction.
@@ -72,7 +74,9 @@ def _build_session_groups(self):
# in the 'groups_by_name' dict. We create the SessionGroup object, if this
# is the first session of that group we encounter.
groups_by_name = {}
run_to_tag_to_content = self._context.hparams_metadata()
run_to_tag_to_content = self._context.hparams_metadata(
self._experiment_id
)
for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
if metadata.SESSION_START_INFO_TAG not in tag_to_content:
continue
@@ -157,7 +161,10 @@ def _build_session_metric_values(self, session_name):
metric_name = metric_info.name
try:
metric_eval = metrics.last_metric_eval(
self._context, session_name, metric_name
self._context,
self._experiment_id,
session_name,
metric_name,
)
except KeyError:
# It's ok if we don't find the metric in the session.
4 changes: 3 additions & 1 deletion tensorboard/plugins/hparams/list_session_groups_test.py
Original file line number Diff line number Diff line change
@@ -1132,7 +1132,9 @@ def _run_handler(self, request):
request_proto = api_pb2.ListSessionGroupsRequest()
text_format.Merge(request, request_proto)
handler = list_session_groups.Handler(
backend_context.Context(self._mock_tb_context), request_proto
context=backend_context.Context(self._mock_tb_context),
experiment_id="123",
request=request_proto,
)
response = handler.run()
# Sort the metric values repeated field in each session group to
Loading
Oops, something went wrong.

0 comments on commit 572f9be

Please sign in to comment.