From f9576f3b1b11e1cfbc71cc440a040799f6d7c267 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Fri, 11 Feb 2022 10:46:35 -0500 Subject: [PATCH] fix: remove enable_iam_auth from downstream kwargs and catch error (#273) --- google/cloud/sql/connector/connector.py | 11 +++++-- tests/unit/test_connector.py | 44 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7f9f606e..b59ad39b 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -108,11 +108,18 @@ def connect( # Use the InstanceConnectionManager to establish an SSL Connection. # # Return a DBAPI connection - + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if instance_connection_string in self._instances: icm = self._instances[instance_connection_string] + if enable_iam_auth != icm._enable_iam_auth: + raise ValueError( + "connect() called with `enable_iam_auth={}`, but previously used " + "enable_iam_auth={}`. If you require both for your use case, " + "please use a new connector.Connector object.".format( + enable_iam_auth, icm._enable_iam_auth + ) + ) else: - enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) icm = InstanceConnectionManager( instance_connection_string, driver, diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 519149da..a72ef5de 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -27,6 +27,25 @@ from typing import Any +class MockInstanceConnectionManager: + _enable_iam_auth: bool + + def __init__( + self, + enable_iam_auth: bool = False, + ) -> None: + self._enable_iam_auth = enable_iam_auth + + def connect( + self, + driver: str, + ip_type: IPTypes, + timeout: int, + **kwargs: Any, + ) -> Any: + return True + + def test_connect_timeout( fake_credentials: Credentials, async_loop: asyncio.AbstractEventLoop ) -> None: @@ -59,6 +78,31 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None: ) +def test_connect_enable_iam_auth_error() -> None: + """Test that calling connect() with different enable_iam_auth + argument values throws error.""" + connect_string = "my-project:my-region:my-instance" + default_connector = connector.Connector() + with patch( + "google.cloud.sql.connector.connector.InstanceConnectionManager" + ) as mock_icm: + mock_icm.return_value = MockInstanceConnectionManager(enable_iam_auth=False) + conn = default_connector.connect( + connect_string, + "pg8000", + enable_iam_auth=False, + ) + assert conn is True + # try to connect using enable_iam_auth=True, should raise error + pytest.raises( + ValueError, + default_connector.connect, + connect_string, + "pg8000", + enable_iam_auth=True, + ) + + def test_default_Connector_Init() -> None: """Test that default Connector __init__ sets properties properly.""" default_connector = connector.Connector()