Skip to content

Commit

Permalink
Allow passing custom gRPC channel credentials to FlyteRemote (flyteor…
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianoKF authored Oct 12, 2021
1 parent bf7cd5e commit 85fd07e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
8 changes: 7 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta

import grpc
from flyteidl.core import literals_pb2 as literals_pb2

from flytekit.clients.friendly import SynchronousFlyteClient
Expand Down Expand Up @@ -122,12 +123,14 @@ def from_config(
default_project: typing.Optional[str] = None,
default_domain: typing.Optional[str] = None,
config_file_path: typing.Optional[str] = None,
grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None,
) -> FlyteRemote:
"""Create a FlyteRemote object using flyte configuration variables and/or environment variable overrides.
:param default_project: default project to use when fetching or executing flyte entities.
:param default_domain: default domain to use when fetching or executing flyte entities.
:param config_file_path: config file to use when connecting to flyte admin. we will use '~/.flyte/config' by default.
:param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned by :func:`grpc.ssl_channel_credentials`
"""

if config_file_path is None:
Expand Down Expand Up @@ -161,6 +164,7 @@ def from_config(
raw_output_data_config=(
common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None
),
grpc_credentials=grpc_credentials,
)

def __init__(
Expand All @@ -176,6 +180,7 @@ def __init__(
annotations: typing.Optional[common_models.Annotations] = None,
image_config: typing.Optional[ImageConfig] = None,
raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None,
grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None,
):
"""Initialize a FlyteRemote object.
Expand All @@ -190,12 +195,13 @@ def __init__(
:param annotations: annotation config
:param image_config: image config
:param raw_output_data_config: location for offloaded data, e.g. in S3
:param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned by :func:`grpc.ssl_channel_credentials`
"""
remote_logger.warning("This feature is still in beta. Its interface and UX is subject to change.")
if flyte_admin_url is None:
raise user_exceptions.FlyteAssertion("Cannot find flyte admin url in config file.")

self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure)
self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure, credentials=grpc_credentials)

# read config files, env vars, host, ssl options for admin client
self._flyte_admin_url = flyte_admin_url
Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,31 @@ def test_form_config(mock_insecure, mock_url):
assert remote._insecure is True
assert remote.default_project == "p1"
assert remote.default_domain == "d1"


@patch("flytekit.clients.raw._ssl_channel_credentials")
@patch("flytekit.clients.raw._secure_channel")
@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
def test_explicit_grpc_channel_credentials(mock_insecure, mock_url, mock_secure_channel, mock_ssl_channel_credentials):
mock_url.get.return_value = "localhost"
mock_insecure.get.return_value = False

# Default mode, no explicit channel credentials
mock_ssl_channel_credentials.reset_mock()
_ = FlyteRemote.from_config("project", "domain")

assert mock_ssl_channel_credentials.called

mock_secure_channel.reset_mock()
mock_ssl_channel_credentials.reset_mock()

# Explicit channel credentials
from grpc import ssl_channel_credentials

credentials = ssl_channel_credentials(b"TEST CERTIFICATE")

_ = FlyteRemote.from_config("project", "domain", grpc_credentials=credentials)
assert mock_secure_channel.called
assert mock_secure_channel.call_args[0][1] == credentials
assert not mock_ssl_channel_credentials.called

0 comments on commit 85fd07e

Please sign in to comment.