From d3dc88f0844bcb377a9e52312e1a99b5ca6e617e Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Mon, 1 Apr 2024 18:04:30 +0400 Subject: [PATCH] Avoid to use `functools.lru_cache` in class methods in `google` provider (#38652) --- .../google/cloud/hooks/compute_ssh.py | 2 +- .../google/common/hooks/base_google.py | 2 +- .../google/cloud/hooks/test_compute_ssh.py | 56 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/compute_ssh.py b/airflow/providers/google/cloud/hooks/compute_ssh.py index 2bc5dcf5142b9..97df5c5525061 100644 --- a/airflow/providers/google/cloud/hooks/compute_ssh.py +++ b/airflow/providers/google/cloud/hooks/compute_ssh.py @@ -334,7 +334,7 @@ def _authorize_compute_engine_instance_metadata(self, pubkey): ) def _authorize_os_login(self, pubkey): - username = self._oslogin_hook._get_credentials_email() + username = self._oslogin_hook._get_credentials_email self.log.info("Importing SSH public key using OSLogin: user=%s", username) expiration = int((time.time() + self.expire_time) * 1000000) ssh_public_key = {"key": pubkey, "expiration_time_usec": expiration} diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 13543243cb56b..ca08f86e78d7f 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -317,7 +317,7 @@ def _get_access_token(self) -> str: credentials.refresh(auth_req) return credentials.token - @functools.lru_cache(maxsize=None) + @functools.cached_property def _get_credentials_email(self) -> str: """ Return the email address associated with the currently logged in account. diff --git a/tests/providers/google/cloud/hooks/test_compute_ssh.py b/tests/providers/google/cloud/hooks/test_compute_ssh.py index 27cfe4fc1b26d..dfcd0d719c0a1 100644 --- a/tests/providers/google/cloud/hooks/test_compute_ssh.py +++ b/tests/providers/google/cloud/hooks/test_compute_ssh.py @@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook +from airflow.providers.google.cloud.hooks.os_login import OSLoginHook pytestmark = pytest.mark.db_test @@ -48,22 +49,35 @@ def test_delegate_to_runtime_error(self): with pytest.raises(RuntimeError): ComputeEngineSSHHook(gcp_conn_id="gcpssh", delegate_to="delegate_to") + def test_os_login_hook(self, mocker): + mock_os_login_hook = mocker.patch.object(OSLoginHook, "__init__", return_value=None, spec=OSLoginHook) + + # Default values + assert ComputeEngineSSHHook()._oslogin_hook + mock_os_login_hook.assert_called_with(gcp_conn_id="google_cloud_default") + + # Custom conn_id + assert ComputeEngineSSHHook(gcp_conn_id="gcpssh")._oslogin_hook + mock_os_login_hook.assert_called_with(gcp_conn_id="gcpssh") + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") - @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") - def test_get_conn_default_configuration( - self, mock_ssh_client, mock_paramiko, mock_os_login_hook, mock_compute_hook - ): - mock_paramiko.SSHException = Exception + def test_get_conn_default_configuration(self, mock_ssh_client, mock_paramiko, mock_compute_hook, mocker): + mock_paramiko.SSHException = RuntimeError mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME" mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ" mock_compute_hook.return_value.project_id = TEST_PROJECT_ID mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP - mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org" - mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [ + mock_os_login_hook = mocker.patch.object( + ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook" + ) + type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock( + return_value="test-example@example.org" + ) + mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [ mock.MagicMock(username="test-username") ] @@ -83,16 +97,10 @@ def test_get_conn_default_configuration( ), ] ) - mock_os_login_hook.assert_has_calls( - [ - mock.call(gcp_conn_id="google_cloud_default"), - mock.call()._get_credentials_email(), - mock.call().import_ssh_public_key( - ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY}, - project_id="test-project-id", - user=mock_os_login_hook.return_value._get_credentials_email.return_value, - ), - ] + mock_os_login_hook.import_ssh_public_key.assert_called_once_with( + ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY}, + project_id="test-project-id", + user="test-example@example.org", ) mock_ssh_client.assert_has_calls( [ @@ -113,7 +121,6 @@ def test_get_conn_default_configuration( [(SSHException, r"Error occurred when establishing SSH connection using Paramiko")], ) @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") - @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance") @@ -122,21 +129,26 @@ def test_get_conn_default_configuration_test_exceptions( mock_connect, mock_ssh_client, mock_paramiko, - mock_os_login_hook, mock_compute_hook, exception_type, error_message, caplog, + mocker, ): - mock_paramiko.SSHException = Exception + mock_paramiko.SSHException = RuntimeError mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME" mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ" mock_compute_hook.return_value.project_id = TEST_PROJECT_ID mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP - mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org" - mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [ + mock_os_login_hook = mocker.patch.object( + ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook" + ) + type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock( + return_value="test-example@example.org" + ) + mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [ mock.MagicMock(username="test-username") ]