Skip to content

Commit

Permalink
feat: add support for using the emulator programatically (#87)
Browse files Browse the repository at this point in the history
* feat: add support for using the emulator programatically

* always set credentials when SPANNER_EMULATOR_HOST is set

* address PR comments

Co-authored-by: larkee <larkee@users.noreply.github.com>
  • Loading branch information
larkee and larkee authored May 26, 2020
1 parent f33c866 commit b22630b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 25 deletions.
36 changes: 22 additions & 14 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import warnings

from google.api_core.gapic_v1 import client_info
from google.auth.credentials import AnonymousCredentials
import google.api_core.client_options

from google.cloud.spanner_admin_instance_v1.gapic.transports import (
Expand Down Expand Up @@ -173,19 +174,27 @@ def __init__(
client_options=None,
query_options=None,
):
self._emulator_host = _get_spanner_emulator_host()

if client_options and type(client_options) == dict:
self._client_options = google.api_core.client_options.from_dict(
client_options
)
else:
self._client_options = client_options

if self._emulator_host:
credentials = AnonymousCredentials()
elif isinstance(credentials, AnonymousCredentials):
self._emulator_host = self._client_options.api_endpoint

# NOTE: This API has no use for the _http argument, but sending it
# will have no impact since the _http() @property only lazily
# creates a working HTTP object.
super(Client, self).__init__(
project=project, credentials=credentials, _http=None
)
self._client_info = client_info
if client_options and type(client_options) == dict:
self._client_options = google.api_core.client_options.from_dict(
client_options
)
else:
self._client_options = client_options

env_query_options = ExecuteSqlRequest.QueryOptions(
optimizer_version=_get_spanner_optimizer_version()
Expand All @@ -198,9 +207,8 @@ def __init__(
warnings.warn(_USER_AGENT_DEPRECATED, DeprecationWarning, stacklevel=2)
self.user_agent = user_agent

if _get_spanner_emulator_host() is not None and (
"http://" in _get_spanner_emulator_host()
or "https://" in _get_spanner_emulator_host()
if self._emulator_host is not None and (
"http://" in self._emulator_host or "https://" in self._emulator_host
):
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)

Expand Down Expand Up @@ -237,9 +245,9 @@ def project_name(self):
def instance_admin_api(self):
"""Helper for session-related API calls."""
if self._instance_admin_api is None:
if _get_spanner_emulator_host() is not None:
if self._emulator_host is not None:
transport = instance_admin_grpc_transport.InstanceAdminGrpcTransport(
channel=grpc.insecure_channel(_get_spanner_emulator_host())
channel=grpc.insecure_channel(target=self._emulator_host)
)
self._instance_admin_api = InstanceAdminClient(
client_info=self._client_info,
Expand All @@ -258,9 +266,9 @@ def instance_admin_api(self):
def database_admin_api(self):
"""Helper for session-related API calls."""
if self._database_admin_api is None:
if _get_spanner_emulator_host() is not None:
if self._emulator_host is not None:
transport = database_admin_grpc_transport.DatabaseAdminGrpcTransport(
channel=grpc.insecure_channel(_get_spanner_emulator_host())
channel=grpc.insecure_channel(target=self._emulator_host)
)
self._database_admin_api = DatabaseAdminClient(
client_info=self._client_info,
Expand Down Expand Up @@ -363,7 +371,7 @@ def instance(
configuration_name,
node_count,
display_name,
_get_spanner_emulator_host(),
self._emulator_host,
)

def list_instances(self, filter_="", page_size=None, page_token=None):
Expand Down
4 changes: 1 addition & 3 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ def spanner_api(self):
channel=grpc.insecure_channel(self._instance.emulator_host)
)
self._spanner_api = SpannerClient(
client_info=client_info,
client_options=client_options,
transport=transport,
client_info=client_info, transport=transport
)
return self._spanner_api
credentials = self._instance._client.credentials
Expand Down
84 changes: 76 additions & 8 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,14 @@ def _constructor_test_helper(
@mock.patch("warnings.warn")
def test_constructor_emulator_host_warning(self, mock_warn, mock_em):
from google.cloud.spanner_v1 import client as MUT
from google.auth.credentials import AnonymousCredentials

expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,)
expected_scopes = None
creds = _make_credentials()
mock_em.return_value = "http://emulator.host.com"
self._constructor_test_helper(expected_scopes, creds)
with mock.patch("google.cloud.spanner_v1.client.AnonymousCredentials") as patch:
expected_creds = patch.return_value = AnonymousCredentials()
self._constructor_test_helper(expected_scopes, creds, expected_creds)
mock_warn.assert_called_once_with(MUT._EMULATOR_HOST_HTTP_SCHEME)

def test_constructor_default_scopes(self):
Expand Down Expand Up @@ -219,6 +222,8 @@ def test_constructor_custom_query_options_env_config(self, mock_ver):
def test_instance_admin_api(self, mock_em):
from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE

mock_em.return_value = None

credentials = _make_credentials()
client_info = mock.Mock()
client_options = mock.Mock()
Expand All @@ -230,7 +235,6 @@ def test_instance_admin_api(self, mock_em):
)
expected_scopes = (SPANNER_ADMIN_SCOPE,)

mock_em.return_value = None
inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient"
with mock.patch(inst_module) as instance_admin_client:
api = client.instance_admin_api
Expand All @@ -250,7 +254,8 @@ def test_instance_admin_api(self, mock_em):
credentials.with_scopes.assert_called_once_with(expected_scopes)

@mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host")
def test_instance_admin_api_emulator(self, mock_em):
def test_instance_admin_api_emulator_env(self, mock_em):
mock_em.return_value = "emulator.host"
credentials = _make_credentials()
client_info = mock.Mock()
client_options = mock.Mock()
Expand All @@ -261,7 +266,38 @@ def test_instance_admin_api_emulator(self, mock_em):
client_options=client_options,
)

mock_em.return_value = "true"
inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient"
with mock.patch(inst_module) as instance_admin_client:
api = client.instance_admin_api

self.assertIs(api, instance_admin_client.return_value)

# API instance is cached
again = client.instance_admin_api
self.assertIs(again, api)

self.assertEqual(len(instance_admin_client.call_args_list), 1)
called_args, called_kw = instance_admin_client.call_args
self.assertEqual(called_args, ())
self.assertEqual(called_kw["client_info"], client_info)
self.assertEqual(called_kw["client_options"], client_options)
self.assertIn("transport", called_kw)
self.assertNotIn("credentials", called_kw)

def test_instance_admin_api_emulator_code(self):
from google.auth.credentials import AnonymousCredentials
from google.api_core.client_options import ClientOptions

credentials = AnonymousCredentials()
client_info = mock.Mock()
client_options = ClientOptions(api_endpoint="emulator.host")
client = self._make_one(
project=self.PROJECT,
credentials=credentials,
client_info=client_info,
client_options=client_options,
)

inst_module = "google.cloud.spanner_v1.client.InstanceAdminClient"
with mock.patch(inst_module) as instance_admin_client:
api = client.instance_admin_api
Expand All @@ -284,6 +320,7 @@ def test_instance_admin_api_emulator(self, mock_em):
def test_database_admin_api(self, mock_em):
from google.cloud.spanner_v1.client import SPANNER_ADMIN_SCOPE

mock_em.return_value = None
credentials = _make_credentials()
client_info = mock.Mock()
client_options = mock.Mock()
Expand All @@ -295,7 +332,6 @@ def test_database_admin_api(self, mock_em):
)
expected_scopes = (SPANNER_ADMIN_SCOPE,)

mock_em.return_value = None
db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient"
with mock.patch(db_module) as database_admin_client:
api = client.database_admin_api
Expand All @@ -315,7 +351,8 @@ def test_database_admin_api(self, mock_em):
credentials.with_scopes.assert_called_once_with(expected_scopes)

@mock.patch("google.cloud.spanner_v1.client._get_spanner_emulator_host")
def test_database_admin_api_emulator(self, mock_em):
def test_database_admin_api_emulator_env(self, mock_em):
mock_em.return_value = "host:port"
credentials = _make_credentials()
client_info = mock.Mock()
client_options = mock.Mock()
Expand All @@ -326,7 +363,38 @@ def test_database_admin_api_emulator(self, mock_em):
client_options=client_options,
)

mock_em.return_value = "host:port"
db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient"
with mock.patch(db_module) as database_admin_client:
api = client.database_admin_api

self.assertIs(api, database_admin_client.return_value)

# API instance is cached
again = client.database_admin_api
self.assertIs(again, api)

self.assertEqual(len(database_admin_client.call_args_list), 1)
called_args, called_kw = database_admin_client.call_args
self.assertEqual(called_args, ())
self.assertEqual(called_kw["client_info"], client_info)
self.assertEqual(called_kw["client_options"], client_options)
self.assertIn("transport", called_kw)
self.assertNotIn("credentials", called_kw)

def test_database_admin_api_emulator_code(self):
from google.auth.credentials import AnonymousCredentials
from google.api_core.client_options import ClientOptions

credentials = AnonymousCredentials()
client_info = mock.Mock()
client_options = ClientOptions(api_endpoint="emulator.host")
client = self._make_one(
project=self.PROJECT,
credentials=credentials,
client_info=client_info,
client_options=client_options,
)

db_module = "google.cloud.spanner_v1.client.DatabaseAdminClient"
with mock.patch(db_module) as database_admin_client:
api = client.database_admin_api
Expand Down

0 comments on commit b22630b

Please sign in to comment.