Skip to content

Commit

Permalink
refactor: add ConnectionName to ConnectionInfo class (#1212)
Browse files Browse the repository at this point in the history
Adding ConnectionName to ConnectionInfo class

Benefits:
- Cleaner interface, passing connection name around instead of individual args (project, region, instance)
- Consistent debug logs (give access to ConnectionName throughout)
- Allows us to tell if ConnectionInfo is using DNS (if ConnectionName.domain_name is set)
  • Loading branch information
jackwotherspoon authored Dec 9, 2024
1 parent 5af7582 commit 7231c57
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 48 deletions.
22 changes: 10 additions & 12 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.auth.transport import requests

from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
from google.cloud.sql.connector.refresh_utils import retry_50x
Expand Down Expand Up @@ -245,20 +246,16 @@ async def _get_ephemeral(

async def get_connection_info(
self,
project: str,
region: str,
instance: str,
conn_name: ConnectionName,
keys: asyncio.Future,
enable_iam_auth: bool,
) -> ConnectionInfo:
"""Immediately performs a full refresh operation using the Cloud SQL
Admin API.
Args:
project (str): The name of the project the Cloud SQL instance is
located in.
region (str): The region the Cloud SQL instance is located in.
instance (str): Name of the Cloud SQL instance.
conn_name (ConnectionName): The Cloud SQL instance's
connection name.
keys (asyncio.Future): A future to the client's public-private key
pair.
enable_iam_auth (bool): Whether an automatic IAM database
Expand All @@ -278,16 +275,16 @@ async def get_connection_info(

metadata_task = asyncio.create_task(
self._get_metadata(
project,
region,
instance,
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
)

ephemeral_task = asyncio.create_task(
self._get_ephemeral(
project,
instance,
conn_name.project,
conn_name.instance_name,
pub_key,
enable_iam_auth,
)
Expand All @@ -311,6 +308,7 @@ async def get_connection_info(
ephemeral_cert, expiration = await ephemeral_task

return ConnectionInfo(
conn_name,
ephemeral_cert,
metadata["server_ca_cert"],
priv_key,
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from aiofiles.tempfile import TemporaryDirectory

from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import TLSVersionError
from google.cloud.sql.connector.utils import write_to_file
Expand All @@ -38,6 +39,7 @@ class ConnectionInfo:
"""Contains all necessary information to connect securely to the
server-side Proxy running on a Cloud SQL instance."""

conn_name: ConnectionName
client_cert: str
server_ca_cert: str
private_key: bytes
Expand Down
15 changes: 4 additions & 11 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def __init__(
name. To resolve a DNS record to an instance connection name, use
DnsResolver.
Default: DefaultResolver
"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
Expand Down Expand Up @@ -283,8 +282,7 @@ async def connect_async(
conn_name = await self._resolver.resolve(instance_connection_string)
if self._refresh_strategy == RefreshStrategy.LAZY:
logger.debug(
f"['{instance_connection_string}']: Refresh strategy is set"
" to lazy refresh"
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
)
cache = LazyRefreshCache(
conn_name,
Expand All @@ -294,18 +292,15 @@ async def connect_async(
)
else:
logger.debug(
f"['{instance_connection_string}']: Refresh strategy is set"
" to backgound refresh"
f"['{conn_name}']: Refresh strategy is set to backgound refresh"
)
cache = RefreshAheadCache(
conn_name,
self._client,
self._keys,
enable_iam_auth,
)
logger.debug(
f"['{instance_connection_string}']: Connection info added to cache"
)
logger.debug(f"['{conn_name}']: Connection info added to cache")
self._cache[(instance_connection_string, enable_iam_auth)] = cache

connect_func = {
Expand Down Expand Up @@ -344,9 +339,7 @@ async def connect_async(
# the cache and re-raise the error
await self._remove_cached(instance_connection_string, enable_iam_auth)
raise
logger.debug(
f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
)
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
if enable_iam_auth:
formatted_user = format_database_user(
Expand Down
13 changes: 3 additions & 10 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ def __init__(
(Postgres and MySQL) as the default authentication method for all
connections.
"""
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
Expand Down Expand Up @@ -104,20 +99,18 @@ async def _perform_refresh(self) -> ConnectionInfo:
"""
self._refresh_in_progress.set()
logger.debug(
f"['{self._conn_name}']: Connection info refresh " "operation started"
f"['{self._conn_name}']: Connection info refresh operation started"
)

try:
await self._refresh_rate_limiter.acquire()
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._conn_name,
self._keys,
self._enable_iam_auth,
)
logger.debug(
f"['{self._conn_name}']: Connection info " "refresh operation complete"
f"['{self._conn_name}']: Connection info refresh operation complete"
)
logger.debug(
f"['{self._conn_name}']: Current certificate "
Expand Down
10 changes: 1 addition & 9 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,7 @@ def __init__(
(Postgres and MySQL) as the default authentication method for all
connections.
"""
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
self._keys = keys
self._client = client
Expand Down Expand Up @@ -101,9 +95,7 @@ async def connect_info(self) -> ConnectionInfo:
)
try:
conn_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._conn_name,
self._keys,
self._enable_iam_auth,
)
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import dns.asyncresolver

from google.cloud.sql.connector.connection_name import (
_parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError
Expand Down Expand Up @@ -52,7 +55,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
# Attempt to parse records, returning the first valid record.
for record in rdata:
try:
conn_name = _parse_connection_name(record)
conn_name = _parse_connection_name_with_domain_name(record, dns)
return conn_name
except Exception:
continue
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ async def test_Instance_init(
can tell if the connection string that's passed in is formatted correctly.
"""
assert (
cache._project == "test-project"
and cache._region == "test-region"
and cache._instance == "test-instance"
cache._conn_name.project == "test-project"
and cache._conn_name.region == "test-region"
and cache._conn_name.instance_name == "test-instance"
)
assert cache._enable_iam_auth is False

Expand Down Expand Up @@ -283,7 +283,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None

async def test_ConnectionInfo_caches_sslcontext() -> None:
info = ConnectionInfo(
"cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
"", "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
)
# context should default to None
assert info.context is None
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

conn_str = "my-project:my-region:my-instance"
conn_name = ConnectionName("my-project", "my-region", "my-instance")
conn_name_with_domain = ConnectionName(
"my-project", "my-region", "my-instance", "db.example.com"
)


async def test_DefaultResolver() -> None:
Expand Down Expand Up @@ -74,7 +77,7 @@ async def test_DnsResolver_with_dns_name() -> None:
resolver.port = 5053
# Resolution should return first value sorted alphabetically
result = await resolver.resolve("db.example.com")
assert result == conn_name
assert result == conn_name_with_domain


query_text_malformed = """id 1234
Expand Down

0 comments on commit 7231c57

Please sign in to comment.