diff --git a/tensorboard/uploader/proto/server_info.proto b/tensorboard/uploader/proto/server_info.proto index b221124cd1..e2d43e79fc 100644 --- a/tensorboard/uploader/proto/server_info.proto +++ b/tensorboard/uploader/proto/server_info.proto @@ -8,11 +8,21 @@ package tensorboard.service; message ServerInfoRequest { // Client-side TensorBoard version, per `tensorboard.version.VERSION`. string version = 1; + // Information about the plugins for which the client wishes to upload data. + // + // If specified then the list of plugins will be confirmed by the server and + // echoed in the PluginControl.allowed_plugins field. Otherwise the server + // will return the default set of plugins it supports. + // + // If one of the plugins is not supported by the server then it will respond + // with compatibility verdict VERDICT_ERROR. + PluginSpecification plugin_specification = 2; } message ServerInfoResponse { - // Primary bottom-line: is the server compatible with the client, and is - // there anything that the end user should be aware of? + // Primary bottom-line: is the server compatible with the client, can it + // serve its request, and is there anything that the end user should be + // aware of? Compatibility compatibility = 1; // Identifier for a gRPC server providing the `TensorBoardExporterService` and // `TensorBoardWriterService` services (under the `tensorboard.service` proto @@ -20,19 +30,15 @@ message ServerInfoResponse { ApiServer api_server = 2; // How to generate URLs to experiment pages. ExperimentUrlFormat url_format = 3; - // For which plugins should we upload data? (Even if the uploader is - // structurally capable of uploading data from many plugins, we only actually - // upload data that can be currently displayed in TensorBoard.dev. Otherwise, - // users may be surprised to see that experiments that they uploaded a while - // ago and have since shared or published now have extra information that - // they didn't realize had been uploaded.) + // Information about the plugins for which data should be uploaded. // - // The client may always choose to upload less data than is permitted by this - // field: e.g., if the end user specifies not to upload data for a given - // plugin, or the client does not yet support uploading some kind of data. + // If PluginSpecification.requested_plugins is specified then + // that list of plugins will be confirmed by the server and echoed in the + // the response. Otherwise the server will return the default set of + // plugins it supports. // - // If this field is omitted, there are no upfront restrictions on what the - // client may send. + // The client should only upload data for the plugins in the response even + // if it is capable of uploading more data. PluginControl plugin_control = 4; } @@ -74,8 +80,15 @@ message ExperimentUrlFormat { string id_placeholder = 2; } +message PluginSpecification { + // Plugins for which the client wishes to upload data. These are plugin names + // as stored in the the `SummaryMetadata.plugin_data.plugin_name` proto + // field. + repeated string upload_plugins = 2; +} + message PluginControl { - // Only send data from plugins with these names. These are plugin names as + // Plugins for which data should be uploaded. These are plugin names as // stored in the the `SummaryMetadata.plugin_data.plugin_name` proto field. repeated string allowed_plugins = 1; } diff --git a/tensorboard/uploader/server_info.py b/tensorboard/uploader/server_info.py index f031830dcc..6416b694e0 100644 --- a/tensorboard/uploader/server_info.py +++ b/tensorboard/uploader/server_info.py @@ -21,6 +21,7 @@ from google.protobuf import message import requests +from absl import logging from tensorboard import version from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.uploader.proto import server_info_pb2 @@ -30,19 +31,31 @@ _REQUEST_TIMEOUT_SECONDS = 10 -def _server_info_request(): +def _server_info_request(upload_plugins): + """Generates a ServerInfoRequest + + Args: + upload_plugins: List of plugin names requested by the user and to be + verified by the server. + + Returns: + A `server_info_pb2.ServerInfoRequest` message. + """ request = server_info_pb2.ServerInfoRequest() request.version = version.VERSION + request.plugin_specification.upload_plugins[:] = upload_plugins return request -def fetch_server_info(origin): +def fetch_server_info(origin, upload_plugins): """Fetches server info from a remote server. Args: origin: The server with which to communicate. Should be a string like "https://tensorboard.dev", including protocol, host, and (if needed) port. + upload_plugins: List of plugins names requested by the user and to be + verified by the server. Returns: A `server_info_pb2.ServerInfoResponse` message. @@ -52,7 +65,9 @@ def fetch_server_info(origin): communicate with the remote server. """ endpoint = "%s/api/uploader" % origin - post_body = _server_info_request().SerializeToString() + server_info_request = _server_info_request(upload_plugins) + post_body = server_info_request.SerializeToString() + logging.info("Requested server info: <%r>", server_info_request) try: response = requests.post( endpoint, @@ -75,13 +90,15 @@ def fetch_server_info(origin): ) -def create_server_info(frontend_origin, api_endpoint): +def create_server_info(frontend_origin, api_endpoint, upload_plugins): """Manually creates server info given a frontend and backend. Args: frontend_origin: The origin of the TensorBoard.dev frontend, like "https://tensorboard.dev" or "http://localhost:8000". api_endpoint: As to `server_info_pb2.ApiServer.endpoint`. + upload_plugins: List of plugin names requested by the user and to be + verified by the server. Returns: A `server_info_pb2.ServerInfoResponse` message. @@ -95,6 +112,7 @@ def create_server_info(frontend_origin, api_endpoint): placeholder = "{%s}" % placeholder url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder) url_format.id_placeholder = placeholder + result.plugin_control.allowed_plugins[:] = upload_plugins return result diff --git a/tensorboard/uploader/server_info_test.py b/tensorboard/uploader/server_info_test.py index 129bf36b2a..e714486b46 100644 --- a/tensorboard/uploader/server_info_test.py +++ b/tensorboard/uploader/server_info_test.py @@ -70,12 +70,31 @@ def app(request): body = request.get_data() request_pb = server_info_pb2.ServerInfoRequest.FromString(body) self.assertEqual(request_pb.version, version.VERSION) + self.assertEqual(request_pb.plugin_specification.upload_plugins, []) return wrappers.BaseResponse(expected_result.SerializeToString()) origin = self._start_server(app) - result = server_info.fetch_server_info(origin) + result = server_info.fetch_server_info(origin, []) self.assertEqual(result, expected_result) + def test_fetches_with_plugins(self): + @wrappers.BaseRequest.application + def app(request): + body = request.get_data() + request_pb = server_info_pb2.ServerInfoRequest.FromString(body) + self.assertEqual(request_pb.version, version.VERSION) + self.assertEqual( + request_pb.plugin_specification.upload_plugins, + ["plugin1", "plugin2"], + ) + return wrappers.BaseResponse( + server_info_pb2.ServerInfoResponse().SerializeToString() + ) + + origin = self._start_server(app) + result = server_info.fetch_server_info(origin, ["plugin1", "plugin2"]) + self.assertIsNotNone(result) + def test_econnrefused(self): (family, localhost) = _localhost() s = socket.socket(family) @@ -83,7 +102,7 @@ def test_econnrefused(self): self.addCleanup(s.close) port = s.getsockname()[1] with self.assertRaises(server_info.CommunicationError) as cm: - server_info.fetch_server_info("http://localhost:%d" % port) + server_info.fetch_server_info("http://localhost:%d" % port, []) msg = str(cm.exception) self.assertIn("Failed to connect to backend", msg) if os.name != "nt": @@ -97,7 +116,7 @@ def app(request): origin = self._start_server(app) with self.assertRaises(server_info.CommunicationError) as cm: - server_info.fetch_server_info(origin) + server_info.fetch_server_info(origin, []) msg = str(cm.exception) self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg) self.assertIn("very sad", msg) @@ -110,7 +129,7 @@ def app(request): origin = self._start_server(app) with self.assertRaises(server_info.CommunicationError) as cm: - server_info.fetch_server_info(origin) + server_info.fetch_server_info(origin, []) msg = str(cm.exception) self.assertIn("Corrupt response from backend", msg) self.assertIn("an unlikely proto", msg) @@ -123,7 +142,7 @@ def app(request): return wrappers.BaseResponse(result.SerializeToString()) origin = self._start_server(app) - result = server_info.fetch_server_info(origin) + result = server_info.fetch_server_info(origin, []) expected_user_agent = "tensorboard/%s" % version.VERSION self.assertEqual(result.compatibility.details, expected_user_agent) @@ -131,10 +150,10 @@ def app(request): class CreateServerInfoTest(tb_test.TestCase): """Tests for `create_server_info`.""" - def test(self): + def test_response(self): frontend = "http://localhost:8080" backend = "localhost:10000" - result = server_info.create_server_info(frontend, backend) + result = server_info.create_server_info(frontend, backend, []) expected_compatibility = server_info_pb2.Compatibility() expected_compatibility.verdict = server_info_pb2.VERDICT_OK @@ -152,6 +171,19 @@ def test(self): expected_url = "http://localhost:8080/experiment/123/" self.assertEqual(actual_url, expected_url) + self.assertEqual(result.plugin_control.allowed_plugins, []) + + def test_response_with_plugins(self): + frontend = "http://localhost:8080" + backend = "localhost:10000" + result = server_info.create_server_info( + frontend, backend, ["plugin1", "plugin2"] + ) + + self.assertEqual( + result.plugin_control.allowed_plugins, ["plugin1", "plugin2"] + ) + class ExperimentUrlTest(tb_test.TestCase): """Tests for `experiment_url`.""" diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py index 546fca3258..79562157af 100644 --- a/tensorboard/uploader/uploader_main.py +++ b/tensorboard/uploader/uploader_main.py @@ -158,6 +158,15 @@ def _define_flags(parser): default=None, help="Experiment description. Markdown format. Max 600 characters.", ) + upload.add_argument( + "--plugins", + type=str, + nargs="*", + default=[], + help="List of plugins for which data should be uploaded. If " + "unspecified then data will be uploaded for all plugins supported by " + "the server.", + ) update_metadata = subparsers.add_parser( "update-metadata", @@ -733,9 +742,13 @@ def _get_intent(flags): def _get_server_info(flags): origin = flags.origin or _DEFAULT_ORIGIN + plugins = getattr(flags, "plugins", []) + if flags.api_endpoint and not flags.origin: - return server_info_lib.create_server_info(origin, flags.api_endpoint) - server_info = server_info_lib.fetch_server_info(origin) + return server_info_lib.create_server_info( + origin, flags.api_endpoint, plugins + ) + server_info = server_info_lib.fetch_server_info(origin, plugins) # Override with any API server explicitly specified on the command # line, but only if the server accepted our initial handshake. if flags.api_endpoint and server_info.api_server.endpoint: