From f62c4b49d4db84cffb74b63922389dd9e11bccef Mon Sep 17 00:00:00 2001
From: Brian Dubois <bdubois@google.com>
Date: Fri, 20 Mar 2020 08:10:47 -0400
Subject: [PATCH 1/4] Revert "Revert "Add --plugins option to uploader (#3377)"
 (#3400)"

This reverts commit 2b2a976b03777a5c6ac6456370ec8c01b80c35e9.
Subsequent commits will fix the issue that caused the original
revert.
---
 tensorboard/uploader/proto/server_info.proto | 41 +++++++++++------
 tensorboard/uploader/server_info.py          | 26 +++++++++--
 tensorboard/uploader/server_info_test.py     | 46 +++++++++++++++++---
 tensorboard/uploader/uploader_main.py        | 16 ++++++-
 4 files changed, 102 insertions(+), 27 deletions(-)

diff --git a/tensorboard/uploader/proto/server_info.proto b/tensorboard/uploader/proto/server_info.proto
index b221124cd1..e2d43e79fc 100644
--- a/tensorboard/uploader/proto/server_info.proto
+++ b/tensorboard/uploader/proto/server_info.proto
@@ -8,11 +8,21 @@ package tensorboard.service;
 message ServerInfoRequest {
   // Client-side TensorBoard version, per `tensorboard.version.VERSION`.
   string version = 1;
+  // Information about the plugins for which the client wishes to upload data.
+  //
+  // If specified then the list of plugins will be confirmed by the server and
+  // echoed in the PluginControl.allowed_plugins field. Otherwise the server
+  // will return the default set of plugins it supports.
+  //
+  // If one of the plugins is not supported by the server then it will respond
+  // with compatibility verdict VERDICT_ERROR.
+  PluginSpecification plugin_specification = 2;
 }
 
 message ServerInfoResponse {
-  // Primary bottom-line: is the server compatible with the client, and is
-  // there anything that the end user should be aware of?
+  // Primary bottom-line: is the server compatible with the client, can it
+  // serve its request, and is there anything that the end user should be
+  // aware of?
   Compatibility compatibility = 1;
   // Identifier for a gRPC server providing the `TensorBoardExporterService` and
   // `TensorBoardWriterService` services (under the `tensorboard.service` proto
@@ -20,19 +30,15 @@ message ServerInfoResponse {
   ApiServer api_server = 2;
   // How to generate URLs to experiment pages.
   ExperimentUrlFormat url_format = 3;
-  // For which plugins should we upload data? (Even if the uploader is
-  // structurally capable of uploading data from many plugins, we only actually
-  // upload data that can be currently displayed in TensorBoard.dev. Otherwise,
-  // users may be surprised to see that experiments that they uploaded a while
-  // ago and have since shared or published now have extra information that
-  // they didn't realize had been uploaded.)
+  // Information about the plugins for which data should be uploaded.
   //
-  // The client may always choose to upload less data than is permitted by this
-  // field: e.g., if the end user specifies not to upload data for a given
-  // plugin, or the client does not yet support uploading some kind of data.
+  // If PluginSpecification.requested_plugins is specified then
+  // that list of plugins will be confirmed by the server and echoed in the
+  // the response. Otherwise the server will return the default set of
+  // plugins it supports.
   //
-  // If this field is omitted, there are no upfront restrictions on what the
-  // client may send.
+  // The client should only upload data for the plugins in the response even
+  // if it is capable of uploading more data.
   PluginControl plugin_control = 4;
 }
 
@@ -74,8 +80,15 @@ message ExperimentUrlFormat {
   string id_placeholder = 2;
 }
 
+message PluginSpecification {
+  // Plugins for which the client wishes to upload data. These are plugin names
+  // as stored in the the `SummaryMetadata.plugin_data.plugin_name` proto
+  // field.
+  repeated string upload_plugins = 2;
+}
+
 message PluginControl {
-  // Only send data from plugins with these names. These are plugin names as
+  // Plugins for which data should be uploaded. These are plugin names as
   // stored in the the `SummaryMetadata.plugin_data.plugin_name` proto field.
   repeated string allowed_plugins = 1;
 }
diff --git a/tensorboard/uploader/server_info.py b/tensorboard/uploader/server_info.py
index f031830dcc..6416b694e0 100644
--- a/tensorboard/uploader/server_info.py
+++ b/tensorboard/uploader/server_info.py
@@ -21,6 +21,7 @@
 from google.protobuf import message
 import requests
 
+from absl import logging
 from tensorboard import version
 from tensorboard.plugins.scalar import metadata as scalars_metadata
 from tensorboard.uploader.proto import server_info_pb2
@@ -30,19 +31,31 @@
 _REQUEST_TIMEOUT_SECONDS = 10
 
 
-def _server_info_request():
+def _server_info_request(upload_plugins):
+    """Generates a ServerInfoRequest
+
+    Args:
+      upload_plugins: List of plugin names requested by the user and to be
+        verified by the server.
+
+    Returns:
+      A `server_info_pb2.ServerInfoRequest` message.
+    """
     request = server_info_pb2.ServerInfoRequest()
     request.version = version.VERSION
+    request.plugin_specification.upload_plugins[:] = upload_plugins
     return request
 
 
-def fetch_server_info(origin):
+def fetch_server_info(origin, upload_plugins):
     """Fetches server info from a remote server.
 
     Args:
       origin: The server with which to communicate. Should be a string
         like "https://tensorboard.dev", including protocol, host, and (if
         needed) port.
+      upload_plugins: List of plugins names requested by the user and to be
+        verified by the server.
 
     Returns:
       A `server_info_pb2.ServerInfoResponse` message.
@@ -52,7 +65,9 @@ def fetch_server_info(origin):
         communicate with the remote server.
     """
     endpoint = "%s/api/uploader" % origin
-    post_body = _server_info_request().SerializeToString()
+    server_info_request = _server_info_request(upload_plugins)
+    post_body = server_info_request.SerializeToString()
+    logging.info("Requested server info: <%r>", server_info_request)
     try:
         response = requests.post(
             endpoint,
@@ -75,13 +90,15 @@ def fetch_server_info(origin):
         )
 
 
-def create_server_info(frontend_origin, api_endpoint):
+def create_server_info(frontend_origin, api_endpoint, upload_plugins):
     """Manually creates server info given a frontend and backend.
 
     Args:
       frontend_origin: The origin of the TensorBoard.dev frontend, like
         "https://tensorboard.dev" or "http://localhost:8000".
       api_endpoint: As to `server_info_pb2.ApiServer.endpoint`.
+      upload_plugins: List of plugin names requested by the user and to be
+        verified by the server.
 
     Returns:
       A `server_info_pb2.ServerInfoResponse` message.
@@ -95,6 +112,7 @@ def create_server_info(frontend_origin, api_endpoint):
         placeholder = "{%s}" % placeholder
     url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder)
     url_format.id_placeholder = placeholder
+    result.plugin_control.allowed_plugins[:] = upload_plugins
     return result
 
 
diff --git a/tensorboard/uploader/server_info_test.py b/tensorboard/uploader/server_info_test.py
index 129bf36b2a..e714486b46 100644
--- a/tensorboard/uploader/server_info_test.py
+++ b/tensorboard/uploader/server_info_test.py
@@ -70,12 +70,31 @@ def app(request):
             body = request.get_data()
             request_pb = server_info_pb2.ServerInfoRequest.FromString(body)
             self.assertEqual(request_pb.version, version.VERSION)
+            self.assertEqual(request_pb.plugin_specification.upload_plugins, [])
             return wrappers.BaseResponse(expected_result.SerializeToString())
 
         origin = self._start_server(app)
-        result = server_info.fetch_server_info(origin)
+        result = server_info.fetch_server_info(origin, [])
         self.assertEqual(result, expected_result)
 
+    def test_fetches_with_plugins(self):
+        @wrappers.BaseRequest.application
+        def app(request):
+            body = request.get_data()
+            request_pb = server_info_pb2.ServerInfoRequest.FromString(body)
+            self.assertEqual(request_pb.version, version.VERSION)
+            self.assertEqual(
+                request_pb.plugin_specification.upload_plugins,
+                ["plugin1", "plugin2"],
+            )
+            return wrappers.BaseResponse(
+                server_info_pb2.ServerInfoResponse().SerializeToString()
+            )
+
+        origin = self._start_server(app)
+        result = server_info.fetch_server_info(origin, ["plugin1", "plugin2"])
+        self.assertIsNotNone(result)
+
     def test_econnrefused(self):
         (family, localhost) = _localhost()
         s = socket.socket(family)
@@ -83,7 +102,7 @@ def test_econnrefused(self):
         self.addCleanup(s.close)
         port = s.getsockname()[1]
         with self.assertRaises(server_info.CommunicationError) as cm:
-            server_info.fetch_server_info("http://localhost:%d" % port)
+            server_info.fetch_server_info("http://localhost:%d" % port, [])
         msg = str(cm.exception)
         self.assertIn("Failed to connect to backend", msg)
         if os.name != "nt":
@@ -97,7 +116,7 @@ def app(request):
 
         origin = self._start_server(app)
         with self.assertRaises(server_info.CommunicationError) as cm:
-            server_info.fetch_server_info(origin)
+            server_info.fetch_server_info(origin, [])
         msg = str(cm.exception)
         self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg)
         self.assertIn("very sad", msg)
@@ -110,7 +129,7 @@ def app(request):
 
         origin = self._start_server(app)
         with self.assertRaises(server_info.CommunicationError) as cm:
-            server_info.fetch_server_info(origin)
+            server_info.fetch_server_info(origin, [])
         msg = str(cm.exception)
         self.assertIn("Corrupt response from backend", msg)
         self.assertIn("an unlikely proto", msg)
@@ -123,7 +142,7 @@ def app(request):
             return wrappers.BaseResponse(result.SerializeToString())
 
         origin = self._start_server(app)
-        result = server_info.fetch_server_info(origin)
+        result = server_info.fetch_server_info(origin, [])
         expected_user_agent = "tensorboard/%s" % version.VERSION
         self.assertEqual(result.compatibility.details, expected_user_agent)
 
@@ -131,10 +150,10 @@ def app(request):
 class CreateServerInfoTest(tb_test.TestCase):
     """Tests for `create_server_info`."""
 
-    def test(self):
+    def test_response(self):
         frontend = "http://localhost:8080"
         backend = "localhost:10000"
-        result = server_info.create_server_info(frontend, backend)
+        result = server_info.create_server_info(frontend, backend, [])
 
         expected_compatibility = server_info_pb2.Compatibility()
         expected_compatibility.verdict = server_info_pb2.VERDICT_OK
@@ -152,6 +171,19 @@ def test(self):
         expected_url = "http://localhost:8080/experiment/123/"
         self.assertEqual(actual_url, expected_url)
 
+        self.assertEqual(result.plugin_control.allowed_plugins, [])
+
+    def test_response_with_plugins(self):
+        frontend = "http://localhost:8080"
+        backend = "localhost:10000"
+        result = server_info.create_server_info(
+            frontend, backend, ["plugin1", "plugin2"]
+        )
+
+        self.assertEqual(
+            result.plugin_control.allowed_plugins, ["plugin1", "plugin2"]
+        )
+
 
 class ExperimentUrlTest(tb_test.TestCase):
     """Tests for `experiment_url`."""
diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py
index 546fca3258..75deb6366f 100644
--- a/tensorboard/uploader/uploader_main.py
+++ b/tensorboard/uploader/uploader_main.py
@@ -159,6 +159,16 @@ def _define_flags(parser):
         help="Experiment description. Markdown format.  Max 600 characters.",
     )
 
+    upload.add_argument(
+        "--plugins",
+        type=str,
+        nargs="*",
+        default=[],
+        help="List of plugins for which data should be uploaded. If "
+        "unspecified then data will be uploaded for all plugins supported by "
+        "the server.",
+    )
+
     update_metadata = subparsers.add_parser(
         "update-metadata",
         help="change the name, description, or other user "
@@ -734,8 +744,10 @@ def _get_intent(flags):
 def _get_server_info(flags):
     origin = flags.origin or _DEFAULT_ORIGIN
     if flags.api_endpoint and not flags.origin:
-        return server_info_lib.create_server_info(origin, flags.api_endpoint)
-    server_info = server_info_lib.fetch_server_info(origin)
+        return server_info_lib.create_server_info(
+            origin, flags.api_endpoint, flags.plugins
+        )
+    server_info = server_info_lib.fetch_server_info(origin, flags.plugins)
     # Override with any API server explicitly specified on the command
     # line, but only if the server accepted our initial handshake.
     if flags.api_endpoint and server_info.api_server.endpoint:

From 2266bd76f7640398d027c1ac9e9859453b5ad7d5 Mon Sep 17 00:00:00 2001
From: Brian Dubois <bdubois@google.com>
Date: Fri, 20 Mar 2020 08:36:45 -0400
Subject: [PATCH 2/4] Check for plugins flag before using it.

Most tensorboard CLI subcommands do not define --plugins flag so we need
to explicitly check that FLAGS.plugins exists before using it.
---
 tensorboard/uploader/uploader_main.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py
index 75deb6366f..eabe23e643 100644
--- a/tensorboard/uploader/uploader_main.py
+++ b/tensorboard/uploader/uploader_main.py
@@ -743,11 +743,13 @@ def _get_intent(flags):
 
 def _get_server_info(flags):
     origin = flags.origin or _DEFAULT_ORIGIN
+    plugins = flags.plugins if hasattr(flags, 'plugins') else []
+
     if flags.api_endpoint and not flags.origin:
         return server_info_lib.create_server_info(
-            origin, flags.api_endpoint, flags.plugins
+            origin, flags.api_endpoint, plugins
         )
-    server_info = server_info_lib.fetch_server_info(origin, flags.plugins)
+    server_info = server_info_lib.fetch_server_info(origin, plugins)
     # Override with any API server explicitly specified on the command
     # line, but only if the server accepted our initial handshake.
     if flags.api_endpoint and server_info.api_server.endpoint:

From 0777b730f61e517d7b2f077e58d76a4dcf40ba10 Mon Sep 17 00:00:00 2001
From: Brian Dubois <bdubois@google.com>
Date: Fri, 20 Mar 2020 09:07:10 -0400
Subject: [PATCH 3/4] "black ."

---
 tensorboard/uploader/uploader_main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py
index eabe23e643..ed9a307e31 100644
--- a/tensorboard/uploader/uploader_main.py
+++ b/tensorboard/uploader/uploader_main.py
@@ -743,7 +743,7 @@ def _get_intent(flags):
 
 def _get_server_info(flags):
     origin = flags.origin or _DEFAULT_ORIGIN
-    plugins = flags.plugins if hasattr(flags, 'plugins') else []
+    plugins = flags.plugins if hasattr(flags, "plugins") else []
 
     if flags.api_endpoint and not flags.origin:
         return server_info_lib.create_server_info(

From fc4597873090a656bdcb96163b3bf7534c104039 Mon Sep 17 00:00:00 2001
From: Brian Dubois <bdubois@google.com>
Date: Fri, 20 Mar 2020 14:13:09 -0400
Subject: [PATCH 4/4] Respond to PR comments.

---
 tensorboard/uploader/uploader_main.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py
index ed9a307e31..79562157af 100644
--- a/tensorboard/uploader/uploader_main.py
+++ b/tensorboard/uploader/uploader_main.py
@@ -158,7 +158,6 @@ def _define_flags(parser):
         default=None,
         help="Experiment description. Markdown format.  Max 600 characters.",
     )
-
     upload.add_argument(
         "--plugins",
         type=str,
@@ -743,7 +742,7 @@ def _get_intent(flags):
 
 def _get_server_info(flags):
     origin = flags.origin or _DEFAULT_ORIGIN
-    plugins = flags.plugins if hasattr(flags, "plugins") else []
+    plugins = getattr(flags, "plugins", [])
 
     if flags.api_endpoint and not flags.origin:
         return server_info_lib.create_server_info(