Skip to content

Commit

Permalink
fix: remove enable_iam_auth from downstream kwargs and catch error (#273
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jackwotherspoon authored Feb 11, 2022
1 parent cdfcc72 commit f9576f3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
11 changes: 9 additions & 2 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,18 @@ def connect(
# Use the InstanceConnectionManager to establish an SSL Connection.
#
# Return a DBAPI connection

enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
if instance_connection_string in self._instances:
icm = self._instances[instance_connection_string]
if enable_iam_auth != icm._enable_iam_auth:
raise ValueError(
"connect() called with `enable_iam_auth={}`, but previously used "
"enable_iam_auth={}`. If you require both for your use case, "
"please use a new connector.Connector object.".format(
enable_iam_auth, icm._enable_iam_auth
)
)
else:
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
icm = InstanceConnectionManager(
instance_connection_string,
driver,
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@
from typing import Any


class MockInstanceConnectionManager:
_enable_iam_auth: bool

def __init__(
self,
enable_iam_auth: bool = False,
) -> None:
self._enable_iam_auth = enable_iam_auth

def connect(
self,
driver: str,
ip_type: IPTypes,
timeout: int,
**kwargs: Any,
) -> Any:
return True


def test_connect_timeout(
fake_credentials: Credentials, async_loop: asyncio.AbstractEventLoop
) -> None:
Expand Down Expand Up @@ -59,6 +78,31 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None:
)


def test_connect_enable_iam_auth_error() -> None:
"""Test that calling connect() with different enable_iam_auth
argument values throws error."""
connect_string = "my-project:my-region:my-instance"
default_connector = connector.Connector()
with patch(
"google.cloud.sql.connector.connector.InstanceConnectionManager"
) as mock_icm:
mock_icm.return_value = MockInstanceConnectionManager(enable_iam_auth=False)
conn = default_connector.connect(
connect_string,
"pg8000",
enable_iam_auth=False,
)
assert conn is True
# try to connect using enable_iam_auth=True, should raise error
pytest.raises(
ValueError,
default_connector.connect,
connect_string,
"pg8000",
enable_iam_auth=True,
)


def test_default_Connector_Init() -> None:
"""Test that default Connector __init__ sets properties properly."""
default_connector = connector.Connector()
Expand Down

0 comments on commit f9576f3

Please sign in to comment.