Skip to content

Commit

Permalink
feat: add universe domain support to Connector (TPC) (#1045)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Apr 10, 2024
1 parent a9a1d0a commit b1e9dee
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 16 deletions.
35 changes: 34 additions & 1 deletion google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
logger = logging.getLogger(name=__name__)

ASYNC_DRIVERS = ["asyncpg"]
_DEFAULT_SCHEME = "https://"
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"


class Connector:
Expand All @@ -58,6 +61,7 @@ def __init__(
quota_project: Optional[str] = None,
sqladmin_api_endpoint: Optional[str] = None,
user_agent: Optional[str] = None,
universe_domain: Optional[str] = None,
) -> None:
"""Initializes a Connector instance.
Expand Down Expand Up @@ -90,6 +94,10 @@ def __init__(
sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL
Admin API endpoint. Defaults to "https://sqladmin.googleapis.com",
this argument should only be used in development.
universe_domain (str): The universe domain for Cloud SQL API calls.
Default: "googleapis.com".
"""
# if event loop is given, use for background tasks
if loop:
Expand Down Expand Up @@ -126,12 +134,36 @@ def __init__(
self._timeout = timeout
self._enable_iam_auth = enable_iam_auth
self._quota_project = quota_project
self._sqladmin_api_endpoint = sqladmin_api_endpoint
self._user_agent = user_agent
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
self._ip_type = ip_type
self._universe_domain = universe_domain
# construct service endpoint for Cloud SQL Admin API calls
if not sqladmin_api_endpoint:
self._sqladmin_api_endpoint = (
_DEFAULT_SCHEME
+ _SQLADMIN_HOST_TEMPLATE.format(universe_domain=self.universe_domain)
)
# otherwise if endpoint override is passed in use it
else:
self._sqladmin_api_endpoint = sqladmin_api_endpoint

# validate that the universe domain of the credentials matches the
# universe domain of the service endpoint
if self._credentials.universe_domain != self.universe_domain:
raise ValueError(
f"The configured universe domain ({self.universe_domain}) does "
"not match the universe domain found in the credentials "
f"({self._credentials.universe_domain}). If you haven't "
"configured the universe domain explicitly, `googleapis.com` "
"is the default."
)

@property
def universe_domain(self) -> str:
return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN

def connect(
self, instance_connection_string: str, driver: str, **kwargs: Any
Expand Down Expand Up @@ -371,6 +403,7 @@ async def create_async_connector(
quota_project: Optional[str] = None,
sqladmin_api_endpoint: Optional[str] = None,
user_agent: Optional[str] = None,
universe_domain: Optional[str] = None,
) -> Connector:
"""Helper function to create Connector object for asyncio connections.
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from cryptography.x509.oid import NameOID
from google.auth.credentials import Credentials

from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN
from google.cloud.sql.connector.instance import ConnectionInfo
from google.cloud.sql.connector.utils import generate_keys
from google.cloud.sql.connector.utils import write_to_file
Expand All @@ -41,6 +42,7 @@ def __init__(
) -> None:
self.token = token
self.expiry = expiry
self._universe_domain = _DEFAULT_UNIVERSE_DOMAIN

@property
def __class__(self) -> Credentials:
Expand Down Expand Up @@ -68,6 +70,11 @@ def expired(self) -> bool:
return False
return True

@property
def universe_domain(self) -> str:
"""The universe domain value."""
return self._universe_domain

@property
def valid(self) -> bool:
"""Checks the validity of the credentials.
Expand Down
15 changes: 0 additions & 15 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,6 @@ async def test_CloudSQLClient_init_(fake_credentials: FakeCredentials) -> None:
await client.close()


async def test_CloudSQLClient_init_default_service_endpoint(
fake_credentials: FakeCredentials,
) -> None:
"""
Test to check whether the __init__ method of CloudSQLClient
can correctly initialize the default service endpoint.
"""
driver = "pg8000"
client = CloudSQLClient(None, "my-quota-project", fake_credentials, driver=driver)
# verify base endpoint is set to proper default
assert client._sqladmin_api_endpoint == "https://sqladmin.googleapis.com"
# close client
await client.close()


@pytest.mark.asyncio
async def test_CloudSQLClient_init_custom_user_agent(
fake_credentials: FakeCredentials,
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,67 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) ->
assert connector._thread.is_alive() is False
# call connector.close a second time
connector.close()


def test_default_universe_domain(fake_credentials: Credentials) -> None:
"""Test that default universe domain and constructed service endpoint are
formatted correctly.
"""
with Connector(credentials=fake_credentials) as connector:
# test universe domain was not configured
assert connector._universe_domain is None
# test property and service endpoint construction
assert connector.universe_domain == "googleapis.com"
assert connector._sqladmin_api_endpoint == "https://sqladmin.googleapis.com"


def test_configured_universe_domain_matches_GDU(fake_credentials: Credentials) -> None:
"""Test that configured universe domain succeeds with matched GDU credentials."""
universe_domain = "googleapis.com"
with Connector(
credentials=fake_credentials, universe_domain=universe_domain
) as connector:
# test universe domain was configured
assert connector._universe_domain == universe_domain
# test property and service endpoint construction
assert connector.universe_domain == universe_domain
assert connector._sqladmin_api_endpoint == f"https://sqladmin.{universe_domain}"


def test_configured_universe_domain_matches_credentials(
fake_credentials: Credentials,
) -> None:
"""Test that configured universe domain succeeds with matching universe
domain credentials.
"""
universe_domain = "test-universe.test"
# set fake credentials to be configured for the universe domain
fake_credentials._universe_domain = universe_domain
with Connector(
credentials=fake_credentials, universe_domain=universe_domain
) as connector:
# test universe domain was configured
assert connector._universe_domain == universe_domain
# test property and service endpoint construction
assert connector.universe_domain == universe_domain
assert connector._sqladmin_api_endpoint == f"https://sqladmin.{universe_domain}"


def test_configured_universe_domain_mismatched_credentials(
fake_credentials: Credentials,
) -> None:
"""Test that configured universe domain errors with mismatched universe
domain credentials.
"""
universe_domain = "test-universe.test"
# credentials have GDU domain ("googleapis.com")
with pytest.raises(ValueError) as exc_info:
Connector(credentials=fake_credentials, universe_domain=universe_domain)
err_msg = (
f"The configured universe domain ({universe_domain}) does "
"not match the universe domain found in the credentials "
f"({fake_credentials.universe_domain}). If you haven't "
"configured the universe domain explicitly, `googleapis.com` "
"is the default."
)
assert exc_info.value.args[0] == err_msg

0 comments on commit b1e9dee

Please sign in to comment.