Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add ConnectionName to ConnectionInfo class #1212

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.auth.transport import requests

from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
from google.cloud.sql.connector.refresh_utils import retry_50x
Expand Down Expand Up @@ -245,20 +246,16 @@ async def _get_ephemeral(

async def get_connection_info(
self,
project: str,
region: str,
instance: str,
conn_name: ConnectionName,
keys: asyncio.Future,
enable_iam_auth: bool,
) -> ConnectionInfo:
"""Immediately performs a full refresh operation using the Cloud SQL
Admin API.

Args:
project (str): The name of the project the Cloud SQL instance is
located in.
region (str): The region the Cloud SQL instance is located in.
instance (str): Name of the Cloud SQL instance.
conn_name (ConnectionName): The Cloud SQL instance's
connection name.
keys (asyncio.Future): A future to the client's public-private key
pair.
enable_iam_auth (bool): Whether an automatic IAM database
Expand All @@ -278,16 +275,16 @@ async def get_connection_info(

metadata_task = asyncio.create_task(
self._get_metadata(
project,
region,
instance,
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
)

ephemeral_task = asyncio.create_task(
self._get_ephemeral(
project,
instance,
conn_name.project,
conn_name.instance_name,
pub_key,
enable_iam_auth,
)
Expand All @@ -311,6 +308,7 @@ async def get_connection_info(
ephemeral_cert, expiration = await ephemeral_task

return ConnectionInfo(
conn_name,
ephemeral_cert,
metadata["server_ca_cert"],
priv_key,
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from aiofiles.tempfile import TemporaryDirectory

from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import TLSVersionError
from google.cloud.sql.connector.utils import write_to_file
Expand All @@ -38,6 +39,7 @@ class ConnectionInfo:
"""Contains all necessary information to connect securely to the
server-side Proxy running on a Cloud SQL instance."""

conn_name: ConnectionName
client_cert: str
server_ca_cert: str
private_key: bytes
Expand Down
15 changes: 4 additions & 11 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def __init__(
name. To resolve a DNS record to an instance connection name, use
DnsResolver.
Default: DefaultResolver

"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
Expand Down Expand Up @@ -283,8 +282,7 @@ async def connect_async(
conn_name = await self._resolver.resolve(instance_connection_string)
if self._refresh_strategy == RefreshStrategy.LAZY:
logger.debug(
f"['{instance_connection_string}']: Refresh strategy is set"
" to lazy refresh"
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
)
cache = LazyRefreshCache(
conn_name,
Expand All @@ -294,18 +292,15 @@ async def connect_async(
)
else:
logger.debug(
f"['{instance_connection_string}']: Refresh strategy is set"
" to backgound refresh"
f"['{conn_name}']: Refresh strategy is set to backgound refresh"
)
cache = RefreshAheadCache(
conn_name,
self._client,
self._keys,
enable_iam_auth,
)
logger.debug(
f"['{instance_connection_string}']: Connection info added to cache"
)
logger.debug(f"['{conn_name}']: Connection info added to cache")
self._cache[(instance_connection_string, enable_iam_auth)] = cache

connect_func = {
Expand Down Expand Up @@ -344,9 +339,7 @@ async def connect_async(
# the cache and re-raise the error
await self._remove_cached(instance_connection_string, enable_iam_auth)
raise
logger.debug(
f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
)
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
if enable_iam_auth:
formatted_user = format_database_user(
Expand Down
13 changes: 3 additions & 10 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ def __init__(
(Postgres and MySQL) as the default authentication method for all
connections.
"""
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
Expand Down Expand Up @@ -104,20 +99,18 @@ async def _perform_refresh(self) -> ConnectionInfo:
"""
self._refresh_in_progress.set()
logger.debug(
f"['{self._conn_name}']: Connection info refresh " "operation started"
f"['{self._conn_name}']: Connection info refresh operation started"
)

try:
await self._refresh_rate_limiter.acquire()
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._conn_name,
self._keys,
self._enable_iam_auth,
)
logger.debug(
f"['{self._conn_name}']: Connection info " "refresh operation complete"
f"['{self._conn_name}']: Connection info refresh operation complete"
)
logger.debug(
f"['{self._conn_name}']: Current certificate "
Expand Down
10 changes: 1 addition & 9 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,7 @@ def __init__(
(Postgres and MySQL) as the default authentication method for all
connections.
"""
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
self._keys = keys
self._client = client
Expand Down Expand Up @@ -101,9 +95,7 @@ async def connect_info(self) -> ConnectionInfo:
)
try:
conn_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._conn_name,
self._keys,
self._enable_iam_auth,
)
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import dns.asyncresolver

from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import (
_parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError

Expand Down Expand Up @@ -52,7 +55,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
# Attempt to parse records, returning the first valid record.
for record in rdata:
try:
conn_name = _parse_connection_name(record)
conn_name = _parse_connection_name_with_domain_name(record, dns)
return conn_name
except Exception:
continue
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/test_connection_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@

import pytest # noqa F401 Needed to run the tests

# fmt: off
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import \
_parse_connection_name_with_domain_name
from google.cloud.sql.connector.connection_name import (
_parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import ConnectionName

# fmt: on


def test_ConnectionName() -> None:
conn_name = ConnectionName("project", "region", "instance")
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ async def test_Instance_init(
can tell if the connection string that's passed in is formatted correctly.
"""
assert (
cache._project == "test-project"
and cache._region == "test-region"
and cache._instance == "test-instance"
cache._conn_name.project == "test-project"
and cache._conn_name.region == "test-region"
and cache._conn_name.instance_name == "test-instance"
)
assert cache._enable_iam_auth is False

Expand Down Expand Up @@ -283,7 +283,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None

async def test_ConnectionInfo_caches_sslcontext() -> None:
info = ConnectionInfo(
"cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
"", "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
)
# context should default to None
assert info.context is None
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

conn_str = "my-project:my-region:my-instance"
conn_name = ConnectionName("my-project", "my-region", "my-instance")
conn_name_with_domain = ConnectionName(
"my-project", "my-region", "my-instance", "db.example.com"
)


async def test_DefaultResolver() -> None:
Expand Down Expand Up @@ -74,7 +77,7 @@ async def test_DnsResolver_with_dns_name() -> None:
resolver.port = 5053
# Resolution should return first value sorted alphabetically
result = await resolver.resolve("db.example.com")
assert result == conn_name
assert result == conn_name_with_domain


query_text_malformed = """id 1234
Expand Down
Loading