From 85fd07e8289e74d095f8d75e0c4439901196d7e8 Mon Sep 17 00:00:00 2001 From: Adrian Rumpold Date: Tue, 12 Oct 2021 15:36:51 +0200 Subject: [PATCH] Allow passing custom gRPC channel credentials to FlyteRemote (#693) --- flytekit/remote/remote.py | 8 ++++++- tests/flytekit/unit/remote/test_remote.py | 28 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 37dbba1083..41a128caf9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -10,6 +10,7 @@ from dataclasses import asdict, dataclass from datetime import datetime, timedelta +import grpc from flyteidl.core import literals_pb2 as literals_pb2 from flytekit.clients.friendly import SynchronousFlyteClient @@ -122,12 +123,14 @@ def from_config( default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None, config_file_path: typing.Optional[str] = None, + grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, ) -> FlyteRemote: """Create a FlyteRemote object using flyte configuration variables and/or environment variable overrides. :param default_project: default project to use when fetching or executing flyte entities. :param default_domain: default domain to use when fetching or executing flyte entities. :param config_file_path: config file to use when connecting to flyte admin. we will use '~/.flyte/config' by default. + :param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned by :func:`grpc.ssl_channel_credentials` """ if config_file_path is None: @@ -161,6 +164,7 @@ def from_config( raw_output_data_config=( common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None ), + grpc_credentials=grpc_credentials, ) def __init__( @@ -176,6 +180,7 @@ def __init__( annotations: typing.Optional[common_models.Annotations] = None, image_config: typing.Optional[ImageConfig] = None, raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None, + grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): """Initialize a FlyteRemote object. @@ -190,12 +195,13 @@ def __init__( :param annotations: annotation config :param image_config: image config :param raw_output_data_config: location for offloaded data, e.g. in S3 + :param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned by :func:`grpc.ssl_channel_credentials` """ remote_logger.warning("This feature is still in beta. Its interface and UX is subject to change.") if flyte_admin_url is None: raise user_exceptions.FlyteAssertion("Cannot find flyte admin url in config file.") - self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure) + self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure, credentials=grpc_credentials) # read config files, env vars, host, ssl options for admin client self._flyte_admin_url = flyte_admin_url diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 52e763d80f..873f67d752 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -240,3 +240,31 @@ def test_form_config(mock_insecure, mock_url): assert remote._insecure is True assert remote.default_project == "p1" assert remote.default_domain == "d1" + + +@patch("flytekit.clients.raw._ssl_channel_credentials") +@patch("flytekit.clients.raw._secure_channel") +@patch("flytekit.configuration.platform.URL") +@patch("flytekit.configuration.platform.INSECURE") +def test_explicit_grpc_channel_credentials(mock_insecure, mock_url, mock_secure_channel, mock_ssl_channel_credentials): + mock_url.get.return_value = "localhost" + mock_insecure.get.return_value = False + + # Default mode, no explicit channel credentials + mock_ssl_channel_credentials.reset_mock() + _ = FlyteRemote.from_config("project", "domain") + + assert mock_ssl_channel_credentials.called + + mock_secure_channel.reset_mock() + mock_ssl_channel_credentials.reset_mock() + + # Explicit channel credentials + from grpc import ssl_channel_credentials + + credentials = ssl_channel_credentials(b"TEST CERTIFICATE") + + _ = FlyteRemote.from_config("project", "domain", grpc_credentials=credentials) + assert mock_secure_channel.called + assert mock_secure_channel.call_args[0][1] == credentials + assert not mock_ssl_channel_credentials.called