From 3c7a32356e762bf47ecc7aeea588990db78f19b6 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 8 Jul 2021 17:39:18 -0700 Subject: [PATCH 01/14] Add context manager API to credentials --- .../identity/_credentials/app_service.py | 10 +++++++++ .../_credentials/authorization_code.py | 12 +++++++++++ .../azure/identity/_credentials/azure_arc.py | 10 +++++++++ .../azure/identity/_credentials/azure_cli.py | 10 +++++++++ .../identity/_credentials/azure_powershell.py | 10 +++++++++ .../azure/identity/_credentials/chained.py | 14 +++++++++++++ .../identity/_credentials/cloud_shell.py | 10 +++++++++ .../identity/_credentials/environment.py | 14 +++++++++++++ .../azure/identity/_credentials/imds.py | 10 +++++++++ .../identity/_credentials/managed_identity.py | 12 +++++++++++ .../identity/_credentials/service_fabric.py | 10 +++++++++ .../identity/_credentials/shared_cache.py | 21 +++++++++++++++++++ .../azure/identity/_credentials/silent.py | 7 +++++++ .../azure/identity/_credentials/vscode.py | 14 +++++++++++++ .../azure/identity/_internal/aad_client.py | 11 ++++++++++ .../_internal/managed_identity_client.py | 11 ++++++++++ .../azure/identity/_internal/msal_client.py | 11 ++++++++++ .../identity/_internal/msal_credentials.py | 11 ++++++++++ 18 files changed, 208 insertions(+) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index 4c2663cd5d88..b01eed9c8676 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -30,6 +30,16 @@ def __init__(self, **kwargs): else: self._available = False + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken if not self._available: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index ce7b33f117f2..587547640744 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -44,6 +44,18 @@ def __init__(self, tenant_id, client_id, authorization_code, redirect_uri, **kwa self._redirect_uri = redirect_uri super(AuthorizationCodeCredential, self).__init__() + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for `scopes`. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index 5fcc5ddfbe9a..7b728c722f6f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -40,6 +40,16 @@ def __init__(self, **kwargs): **kwargs ) + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken if not self._available: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py index 111856a3a090..24e160d768c4 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py @@ -44,6 +44,16 @@ class AzureCliCredential(object): def __init__(self, **kwargs): self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def close(self): + # type: () -> None + """Calling this method is unnecessary.""" + @log_get_token("AzureCliCredential") def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py index ed3befb14462..8ce657c6e261 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py @@ -61,6 +61,16 @@ def __init__(self, **kwargs): # type: (**Any) -> None self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def close(self): + # type: () -> None + """Calling this method is unnecessary.""" + @log_get_token("AzurePowerShellCredential") def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index 239adf55d988..35936acb7679 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -53,6 +53,20 @@ def __init__(self, *credentials): self._successful_credential = None # type: Optional[TokenCredential] self.credentials = credentials + def __enter__(self): + for credential in self.credentials: + credential.__enter__() + return self + + def __exit__(self, *args): + for credential in self.credentials: + credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the transport session of each credential in the chain.""" + self.__exit__() + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken """Request a token from each chained credential, in order, returning the first token received. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py index 17e10feec6d9..ac929f02d97c 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py @@ -33,6 +33,16 @@ def __init__(self, **kwargs): else: self._available = False + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken if not self._available: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py index 97aecf8c03f1..51e77d160372 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py @@ -101,6 +101,20 @@ def __init__(self, **kwargs): else: _LOGGER.info("No environment configuration found.") + def __enter__(self): + if self._credential: + self._credential.__enter__() + return self + + def __exit__(self, *args): + if self._credential: + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + @log_get_token("EnvironmentCredential") def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index ca385eeecb7e..d642dd063b21 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -56,6 +56,16 @@ def __init__(self, **kwargs): self._error_message = None # type: Optional[str] self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + self.__exit__() + def _acquire_token_silently(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index d0a0acef7931..3a291c1e1df3 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -81,6 +81,18 @@ def __init__(self, **kwargs): _LOGGER.info("%s will use IMDS", self.__class__.__name__) self._credential = ImdsCredential(**kwargs) + def __enter__(self): + self._credential.__enter__() + return self + + def __exit__(self, *args): + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + @log_get_token("ManagedIdentityCredential") def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py index 0594a06f84ac..b1f4359424c6 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py @@ -30,6 +30,16 @@ def __init__(self, **kwargs): else: self._available = False + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken if not self._available: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 7d49c9dcac3f..906aaad174fc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -48,6 +48,18 @@ def __init__(self, username=None, **kwargs): else: self._credential = _SharedTokenCacheCredential(username=username, **kwargs) + def __enter__(self): + self._credential.__enter__() + return self + + def __exit__(self, *args): + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + @log_get_token("SharedTokenCacheCredential") def get_token(self, *scopes, **kwargs): # type (*str, **Any) -> AccessToken @@ -84,6 +96,15 @@ def supported(): class _SharedTokenCacheCredential(SharedTokenCacheBase): """The original SharedTokenCacheCredential, which doesn't use msal.ClientApplication""" + def __enter__(self): + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args): + if self._client: + self._client.__exit__(*args) + def get_token(self, *scopes, **kwargs): # type (*str, **Any) -> AccessToken if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index e6996085c199..b1aa1ade8c46 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -41,6 +41,13 @@ def __init__(self, authentication_record, **kwargs): self._client = MsalClient(**kwargs) self._initialized = False + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type (*str, **Any) -> AccessToken if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index 904766782bf7..cd6866f319da 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -125,6 +125,20 @@ class VisualStudioCodeCredential(_VSCodeCredentialBase, GetTokenMixin): user's home tenant or the tenant configured by **tenant_id** or VS Code's user settings. """ + def __enter__(self): + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args): + if self._client: + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index 79f986fca8a5..ceffb908922d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -17,6 +17,17 @@ class AadClient(AadClientBase): + def __enter__(self): + self._pipeline.__enter__() + return self + + def __exit__(self, *args): + self._pipeline.__exit__() + + def close(self): + # type: () -> None + self.__exit__() + def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): # type: (Iterable[str], str, str, Optional[str], **Any) -> AccessToken request = self._get_auth_code_request( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index b593f7394076..2b2164f9a382 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -104,6 +104,17 @@ def _build_pipeline(self, **kwargs): class ManagedIdentityClient(ManagedIdentityClientBase): + def __enter__(self): + self._pipeline.__enter__() + return self + + def __exit__(self, *args): + self._pipeline.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken resource = _scopes_to_resource(*scopes) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py index 87fdcde5a0e4..faec84d439d7 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py @@ -73,6 +73,17 @@ def __init__(self, **kwargs): # pylint:disable=missing-client-constructor-param self._local = threading.local() self._pipeline = build_pipeline(**kwargs) + def __enter__(self): + self._pipeline.__enter__() + return self + + def __exit__(self, *args): + self._pipeline.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:disable=unused-argument # type: (str, Optional[Dict[str, str]], RequestData, Optional[Dict[str, str]], **Any) -> MsalResponse request = HttpRequest("POST", url, headers=headers) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 3732f1595313..8ac10bbd687d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -50,6 +50,17 @@ def __init__(self, client_id, client_credential=None, **kwargs): super(MsalCredential, self).__init__() + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def _get_app(self, **kwargs): # type: (**Any) -> msal.ClientApplication tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs) From 157be95b61e9271a227c095909fba9aeb5613e9c Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 9 Jul 2021 10:31:54 -0700 Subject: [PATCH 02/14] factor base class out of managed identity credentials --- .../identity/_credentials/app_service.py | 48 +++----------- .../azure/identity/_credentials/azure_arc.py | 41 ++++-------- .../identity/_credentials/cloud_shell.py | 49 +++------------ .../identity/_credentials/service_fabric.py | 46 +++----------- .../_internal/managed_identity_base.py | 62 +++++++++++++++++++ .../identity/aio/_credentials/app_service.py | 41 +++--------- .../identity/aio/_credentials/azure_arc.py | 40 +++--------- .../identity/aio/_credentials/cloud_shell.py | 42 +++---------- .../aio/_credentials/service_fabric.py | 40 +++--------- .../aio/_internal/managed_identity_base.py | 56 +++++++++++++++++ 10 files changed, 189 insertions(+), 276 deletions(-) create mode 100644 sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index b01eed9c8676..7da9589097d3 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -6,55 +6,27 @@ import os from typing import TYPE_CHECKING -from azure.core.credentials import AccessToken from azure.core.pipeline.transport import HttpRequest -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables +from .._internal.managed_identity_base import ManagedIdentityBase from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin if TYPE_CHECKING: from typing import Any, Optional -class AppServiceCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(AppServiceCredential, self).__init__() - +class AppServiceCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = ManagedIdentityClient(**client_args) - else: - self._available = False - - def __enter__(self): - self._client.__enter__() - return self - - def __exit__(self, *args): - self._client.__exit__(*args) - - def close(self): - self.__exit__() - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="App Service managed identity configuration not found in environment" - ) - return super(AppServiceCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) - - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self): + # type: () -> str + return "App Service managed identity configuration not found in environment" def _get_client_args(**kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index 7b728c722f6f..ff38ddbc7a83 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -10,35 +10,28 @@ from azure.core.pipeline.transport import HttpRequest from azure.core.pipeline.policies import HTTPPolicy -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables +from .._internal.managed_identity_base import ManagedIdentityBase from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Optional, Union - from azure.core.credentials import AccessToken + from typing import Any, Optional from azure.core.pipeline import PipelineRequest, PipelineResponse - from azure.core.pipeline.policies import SansIOHTTPPolicy - PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy] - - -class AzureArcCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(AzureArcCredential, self).__init__() +class AzureArcCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT) - self._available = url and imds - if self._available: - self._client = ManagedIdentityClient( + if url and imds: + return ManagedIdentityClient( _per_retry_policies=[ArcChallengeAuthPolicy()], request_factory=functools.partial(_get_request, url), **kwargs ) + return None def __enter__(self): self._client.__enter__() @@ -50,21 +43,9 @@ def __exit__(self, *args): def close(self): self.__exit__() - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="Azure Arc managed identity configuration not found in environment" - ) - return super(AzureArcCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) - - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "Azure Arc managed identity configuration not found in environment" def _get_request(url, scope, identity_config): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py index ac929f02d97c..a9fb5ed96432 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py @@ -8,56 +8,27 @@ from azure.core.pipeline.transport import HttpRequest -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin +from .._internal.managed_identity_base import ManagedIdentityBase if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class CloudShellCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(CloudShellCredential, self).__init__() +class CloudShellCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: - self._available = True - self._client = ManagedIdentityClient( - request_factory=functools.partial(_get_request, url), - base_headers={"Metadata": "true"}, - **kwargs + return ManagedIdentityClient( + request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs ) - else: - self._available = False + return None - def __enter__(self): - self._client.__enter__() - return self - - def __exit__(self, *args): - self._client.__exit__(*args) - - def close(self): - self.__exit__() - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="Cloud Shell managed identity configuration not found in environment" - ) - return super(CloudShellCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) - - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "Cloud Shell managed identity configuration not found in environment" def _get_request(url, scope, identity_config): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py index b1f4359424c6..cb8c29a84ccc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py @@ -8,53 +8,25 @@ from azure.core.pipeline.transport import HttpRequest -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables +from .._internal.managed_identity_base import ManagedIdentityBase from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class ServiceFabricCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(ServiceFabricCredential, self).__init__() - +class ServiceFabricCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = ManagedIdentityClient(**client_args) - else: - self._available = False - - def __enter__(self): - self._client.__enter__() - return self - - def __exit__(self, *args): - self._client.__exit__(*args) - - def close(self): - self.__exit__() - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="Service Fabric managed identity configuration not found in environment" - ) - return super(ServiceFabricCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + return ManagedIdentityClient(**client_args) + return None - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "Service Fabric managed identity configuration not found in environment" def _get_client_args(**kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py new file mode 100644 index 000000000000..fba0c90e4071 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -0,0 +1,62 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc +from typing import cast, TYPE_CHECKING + +from .. import CredentialUnavailableError +from .._internal.managed_identity_client import ManagedIdentityClient +from .._internal.get_token_mixin import GetTokenMixin + +if TYPE_CHECKING: + from typing import Any, Optional + from azure.core.credentials import AccessToken + + +class ManagedIdentityBase(GetTokenMixin): + """Base class for internal credentials using ManagedIdentityClient""" + + def __init__(self, **kwargs): + # type: (**Any) -> None + super(ManagedIdentityBase, self).__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] + pass + + @abc.abstractmethod + def get_unavailable_message(self): + # type: () -> str + pass + + def __enter__(self): + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args): + if self._client: + self._client.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + + def get_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return super(ManagedIdentityBase, self).get_token(*scopes, **kwargs) + + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) + + def _request_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py index e09f35006c1e..b587a4ab6a43 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py @@ -4,46 +4,21 @@ # ------------------------------------ from typing import TYPE_CHECKING -from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin -from ... import CredentialUnavailableError from ..._credentials.app_service import _get_client_args if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class AppServiceCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super(AppServiceCredential, self).__init__() - +class AppServiceCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = AsyncManagedIdentityClient(**client_args) - else: - self._available = False - - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="App Service managed identity configuration not found in environment" - ) - - return await super().get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + return AsyncManagedIdentityClient(**client_args) + return None - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "App Service managed identity configuration not found in environment" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index 0b75b63af44a..babc5fb28f63 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -8,55 +8,31 @@ from azure.core.pipeline.policies import AsyncHTTPPolicy -from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin -from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from ..._credentials.azure_arc import _get_request, _get_secret_key if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any, Optional - from azure.core.credentials import AccessToken from azure.core.pipeline import PipelineRequest, PipelineResponse -class AzureArcCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super().__init__() - +class AzureArcCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT) - self._available = url and imds - if self._available: - self._client = AsyncManagedIdentityClient( + if url and imds: + return AsyncManagedIdentityClient( _per_retry_policies=[ArcChallengeAuthPolicy()], request_factory=functools.partial(_get_request, url), **kwargs ) + return None - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="Service Fabric managed identity configuration not found in environment" - ) - - return await super().get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) - - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "Service Fabric managed identity configuration not found in environment" class ArcChallengeAuthPolicy(AsyncHTTPPolicy): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py index 512fc621a21b..140d9afe8bbf 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py @@ -6,49 +6,23 @@ import os from typing import TYPE_CHECKING -from .._internal import AsyncContextManager -from .._internal.get_token_mixin import GetTokenMixin +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from ..._credentials.cloud_shell import _get_request if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class CloudShellCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super(CloudShellCredential, self).__init__() +class CloudShellCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: - self._available = True - self._client = AsyncManagedIdentityClient( - request_factory=functools.partial(_get_request, url), - base_headers={"Metadata": "true"}, - **kwargs, + return AsyncManagedIdentityClient( + request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs ) - else: - self._available = False + return None - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="Cloud Shell managed identity configuration not found in environment" - ) - return await super(CloudShellCredential, self).get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) - - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "Cloud Shell managed identity configuration not found in environment" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py index 5e8de07d763e..1ab3477bab90 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py @@ -4,46 +4,20 @@ # ------------------------------------ from typing import TYPE_CHECKING -from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin -from ... import CredentialUnavailableError from ..._credentials.service_fabric import _get_client_args if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class ServiceFabricCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super().__init__() - +class ServiceFabricCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = AsyncManagedIdentityClient(**client_args) - else: - self._available = False - - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="Service Fabric managed identity configuration not found in environment" - ) - - return await super().get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + return AsyncManagedIdentityClient(**client_args) + return None - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "Service Fabric managed identity configuration not found in environment" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py new file mode 100644 index 000000000000..63d8b6db9b49 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -0,0 +1,56 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc +from typing import cast, TYPE_CHECKING + +from . import AsyncContextManager +from .get_token_mixin import GetTokenMixin +from .managed_identity_client import AsyncManagedIdentityClient +from ... import CredentialUnavailableError + +if TYPE_CHECKING: + from typing import Any, Optional + from azure.core.credentials import AccessToken + + +class AsyncManagedIdentityBase(AsyncContextManager, GetTokenMixin): + """Base class for internal credentials using AsyncManagedIdentityClient""" + + def __init__(self, **kwargs: "Any") -> None: + super().__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": + pass + + @abc.abstractmethod + def get_unavailable_message(self) -> str: + pass + + async def __aenter__(self): + if self._client: + await self._client.__aenter__() + return self + + async def __aexit__(self, *args): + if self._client: + await self._client.__aexit__(*args) + + async def close(self) -> None: + await self.__aexit__() + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return await super().get_token(*scopes, **kwargs) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) + + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) From ae4e1b45f543bbf0db8b2681e55aeb4213371b2b Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 9 Jul 2021 14:18:48 -0700 Subject: [PATCH 03/14] changelog --- sdk/identity/azure-identity/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index c5d81fdaf9b9..20841172a655 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -29,6 +29,10 @@ workarounds when importing transitive dependencies such as pywin32 fails ([#19989](https://github.com/Azure/azure-sdk-for-python/issues/19989)) +- Added context manager methods and `close()` to credentials in the + `azure.identity` namespace. At the end of a `with` block, or when `close()` + is called, these credentials close their underlying transport sessions. + ([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798)) ## 1.7.0b2 (2021-07-08) From d8864269807f241551895b8b89d4940a0edff32e Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 19 Jul 2021 08:35:20 -0700 Subject: [PATCH 04/14] Thanks, McCoy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: McCoy Patiño <39780829+mccoyp@users.noreply.github.com> --- .../azure/identity/aio/_credentials/app_service.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py index b587a4ab6a43..b41274042b11 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py @@ -19,6 +19,5 @@ def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": return AsyncManagedIdentityClient(**client_args) return None - def get_unavailable_message(self): - # type: () -> str + def get_unavailable_message(self) -> str: return "App Service managed identity configuration not found in environment" From 6d37c70397b6a8e237d0bce367986de2b421635a Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 19 Jul 2021 09:05:00 -0700 Subject: [PATCH 05/14] test VisualStudioCodeCredential custom cloud configuration --- .../azure-identity/tests/test_vscode_credential.py | 9 +++++++++ .../tests/test_vscode_credential_async.py | 10 ++++++++++ 2 files changed, 19 insertions(+) diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index ec06e41d14b4..e6db05f56e5f 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -225,6 +225,15 @@ def test_adfs(): assert "adfs" in ex.value.message.lower() +def test_custom_cloud_no_authority(): + """The credential is unavailable when VS Code is configured to use a custom cloud with no known authority""" + + cloud_name = "AzureCustomCloud" + credential = get_credential({"azure.cloud": cloud_name}) + with pytest.raises(CredentialUnavailableError, match="authority.*" + cloud_name): + credential.get_token("scope") + + @pytest.mark.parametrize( "cloud,authority", ( diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index d91332be0320..40c996e39957 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -213,6 +213,16 @@ async def test_adfs(): assert "adfs" in ex.value.message.lower() +@pytest.mark.asyncio +async def test_custom_cloud_no_authority(): + """The credential is unavailable when VS Code is configured to use a cloud with no known authority""" + + cloud_name = "AzureCustomCloud" + credential = get_credential({"azure.cloud": cloud_name}) + with pytest.raises(CredentialUnavailableError, match="authority.*" + cloud_name): + await credential.get_token("scope") + + @pytest.mark.asyncio @pytest.mark.parametrize( "cloud,authority", From 3246dbe92cf0435f2db0837c8893844b222d9d11 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 28 Jul 2021 10:32:30 -0700 Subject: [PATCH 06/14] AadClient.__exit__ passes args down --- .../azure-identity/azure/identity/_internal/aad_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index ceffb908922d..fb27461d8eb4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -22,7 +22,7 @@ def __enter__(self): return self def __exit__(self, *args): - self._pipeline.__exit__() + self._pipeline.__exit__(*args) def close(self): # type: () -> None From bad9497841983387eec0294102cc08eefcbfa147 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 9 Jul 2021 14:15:47 -0700 Subject: [PATCH 07/14] tests --- .../tests/test_chained_credential.py | 35 +++++- .../tests/test_context_manager.py | 106 ++++++++++++++++++ .../azure-identity/tests/test_default.py | 27 ++++- .../tests/test_environment_credential.py | 11 +- .../test_environment_credential_async.py | 38 ++++++- .../tests/test_managed_identity.py | 50 +++++++++ .../tests/test_managed_identity_async.py | 13 ++- .../tests/test_shared_cache_credential.py | 35 +++++- 8 files changed, 306 insertions(+), 9 deletions(-) create mode 100644 sdk/identity/azure-identity/tests/test_context_manager.py diff --git a/sdk/identity/azure-identity/tests/test_chained_credential.py b/sdk/identity/azure-identity/tests/test_chained_credential.py index 85ce50b5035a..db6da5dc80c7 100644 --- a/sdk/identity/azure-identity/tests/test_chained_credential.py +++ b/sdk/identity/azure-identity/tests/test_chained_credential.py @@ -3,9 +3,9 @@ # Licensed under the MIT License. # ------------------------------------ try: - from unittest.mock import Mock + from unittest.mock import MagicMock, Mock except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import MagicMock, Mock # type: ignore from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -17,6 +17,37 @@ import pytest +def test_close(): + credentials = [MagicMock(close=Mock()) for _ in range(5)] + chain = ChainedTokenCredential(*credentials) + + for credential in credentials: + assert credential.__exit__.call_count == 0 + + chain.close() + + for credential in credentials: + assert credential.__exit__.call_count == 1 + + +def test_context_manager(): + credentials = [MagicMock() for _ in range(5)] + chain = ChainedTokenCredential(*credentials) + + for credential in credentials: + assert credential.__enter__.call_count == 0 + assert credential.__exit__.call_count == 0 + + with chain: + for credential in credentials: + assert credential.__enter__.call_count == 1 + assert credential.__exit__.call_count == 0 + + for credential in credentials: + assert credential.__enter__.call_count == 1 + assert credential.__exit__.call_count == 1 + + def test_error_message(): first_error = "first_error" first_credential = Mock( diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py new file mode 100644 index 000000000000..295f657a6bcc --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -0,0 +1,106 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +try: + from unittest.mock import MagicMock, patch +except ImportError: + from mock import MagicMock, patch # type: ignore + +from azure.identity import ( + AzureApplicationCredential, + AzureCliCredential, + AzurePowerShellCredential, + AuthorizationCodeCredential, + CertificateCredential, + ClientSecretCredential, + DeviceCodeCredential, + EnvironmentCredential, + InteractiveBrowserCredential, + SharedTokenCacheCredential, + UsernamePasswordCredential, + VisualStudioCodeCredential, +) +from azure.identity._constants import EnvironmentVariables + +import pytest + +from test_certificate_credential import CERT_PATH +from test_vscode_credential import GET_USER_SETTINGS + + +class CredentialFixture: + def __init__(self, cls, default_kwargs=None, ctor_patch=None, get_token_patch=None): + self.cls = cls + self.get_token_patch = get_token_patch or MagicMock() + self._default_kwargs = default_kwargs or {} + self._ctor_patch = ctor_patch or MagicMock() + + def __call__(self, **kwargs): + with self._ctor_patch: + return self.cls(**dict(self._default_kwargs, **kwargs)) + + +FIXTURES = ( + CredentialFixture( + AuthorizationCodeCredential, + {kwarg: "..." for kwarg in ("tenant_id", "client_id", "authorization_code", "redirect_uri")}, + ), + CredentialFixture(CertificateCredential, {"tenant_id": "...", "client_id": "...", "certificate_path": CERT_PATH}), + CredentialFixture(ClientSecretCredential, {kwarg: "..." for kwarg in ("tenant_id", "client_id", "client_secret")}), + CredentialFixture(DeviceCodeCredential), + CredentialFixture( + EnvironmentCredential, + ctor_patch=patch.dict("os.environ", {var: ".." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True), + ), + CredentialFixture(InteractiveBrowserCredential), + CredentialFixture(UsernamePasswordCredential, {"client_id": "...", "username": "...", "password": "..."}), + CredentialFixture(VisualStudioCodeCredential, ctor_patch=patch(GET_USER_SETTINGS, lambda: {})), +) + +all_fixtures = pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) + + +@all_fixtures +def test_close(fixture): + transport = MagicMock() + credential = fixture(transport=transport) + assert not transport.__enter__.called + assert not transport.__exit__.called + + credential.close() + assert not transport.__enter__.called + assert transport.__exit__.call_count == 1 + + +@all_fixtures +def test_context_manager(fixture): + transport = MagicMock() + credential = fixture(transport=transport) + + with credential: + assert transport.__enter__.call_count == 1 + assert not transport.__exit__.called + + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + +@all_fixtures +def test_exit_args(fixture): + transport = MagicMock() + credential = fixture(transport=transport) + expected_args = ("type", "value", "traceback") + credential.__exit__(*expected_args) + transport.__exit__.assert_called_once_with(*expected_args) + + +@pytest.mark.parametrize( + "cls", (AzureCliCredential, AzureApplicationCredential, AzurePowerShellCredential, EnvironmentCredential, SharedTokenCacheCredential) +) +def test_no_op(cls): + """Credentials that don't allow custom transports, or require initialization or optional config, should have no-op methods""" + credential = cls() + with credential: + pass + credential.close() diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 31c2ba1a71b4..a03a16cb85cc 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -24,9 +24,32 @@ from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache try: - from unittest.mock import Mock, patch + from unittest.mock import MagicMock, Mock, patch except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore + from mock import MagicMock, Mock, patch # type: ignore + + +def test_close(): + transport = MagicMock() + credential = DefaultAzureCredential(transport=transport) + assert not transport.__enter__.called + assert not transport.__exit__.called + + credential.close() + assert not transport.__enter__.called + assert transport.__exit__.called # call count depends on the chain's composition + + +def test_context_manager(): + transport = MagicMock() + credential = DefaultAzureCredential(transport=transport) + + with credential: + assert transport.__enter__.called # call count depends on the chain's composition + assert not transport.__exit__.called + + assert transport.__enter__.called + assert transport.__exit__.called def test_iterates_only_once(): diff --git a/sdk/identity/azure-identity/tests/test_environment_credential.py b/sdk/identity/azure-identity/tests/test_environment_credential.py index 35ce51045ef7..19dc13484342 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential.py @@ -9,7 +9,7 @@ from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock, mock_response, Request, validating_transport +from helpers import mock ALL_VARIABLES = { @@ -20,6 +20,15 @@ } +def test_close_incomplete_configuration(): + EnvironmentCredential().close() + + +def test_context_manager_incomplete_configuration(): + with EnvironmentCredential(): + pass + + def test_incomplete_configuration(): """get_token should raise CredentialUnavailableError for incomplete configuration.""" diff --git a/sdk/identity/azure-identity/tests/test_environment_credential_async.py b/sdk/identity/azure-identity/tests/test_environment_credential_async.py index 634dfcb8a2dd..43fd568b299f 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential_async.py @@ -12,10 +12,46 @@ import pytest from helpers import mock_response, Request -from helpers_async import async_validating_transport +from helpers_async import async_validating_transport, AsyncMockTransport from test_environment_credential import ALL_VARIABLES +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + with mock.patch.dict("os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): + credential = EnvironmentCredential(transport=transport) + assert transport.__aexit__.call_count == 0 + + await credential.close() + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + with mock.patch.dict("os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): + credential = EnvironmentCredential(transport=transport) + + async with credential: + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 0 + + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_close_incomplete_configuration(): + await EnvironmentCredential().close() + + +@pytest.mark.asyncio +async def test_context_manager_incomplete_configuration(): + async with EnvironmentCredential(): + pass + + @pytest.mark.asyncio async def test_incomplete_configuration(): """get_token should raise CredentialUnavailableError for incomplete configuration.""" diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index eb199f08b368..994b25b0d9f0 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -21,6 +21,56 @@ from helpers import build_aad_response, validating_transport, mock_response, Request MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ" +ALL_ENVIRONMENTS = ( + {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # App Service + {EnvironmentVariables.MSI_ENDPOINT: "..."}, # Cloud Shell + { # Service Fabric + EnvironmentVariables.IDENTITY_ENDPOINT: "...", + EnvironmentVariables.IDENTITY_HEADER: "...", + EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...", + }, + {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc + { # token exchange + EnvironmentVariables.AZURE_CLIENT_ID: "...", + EnvironmentVariables.AZURE_TENANT_ID: "...", + EnvironmentVariables.TOKEN_FILE_PATH: __file__, + }, + {}, # IMDS +) + + +@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) +def test_close(environ): + transport = mock.MagicMock() + with mock.patch.dict("os.environ", environ, clear=True): + credential = ManagedIdentityCredential(transport=transport) + assert transport.__exit__.call_count == 0 + + credential.close() + assert transport.__exit__.call_count == 1 + + +@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) +def test_context_manager(environ): + transport = mock.MagicMock() + with mock.patch.dict("os.environ", environ, clear=True): + credential = ManagedIdentityCredential(transport=transport) + + with credential: + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 0 + + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + +def test_close_incomplete_configuration(): + ManagedIdentityCredential().close() + + +def test_context_manager_incomplete_configuration(): + with ManagedIdentityCredential(): + pass ALL_ENVIRONMENTS = ( diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 14f616c404cc..0f2f5088ea44 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -16,7 +16,7 @@ import pytest from helpers import build_aad_response, mock_response, Request -from helpers_async import async_validating_transport, AsyncMockTransport, get_completed_future +from helpers_async import async_validating_transport, AsyncMockTransport from test_managed_identity import ALL_ENVIRONMENTS @@ -94,6 +94,17 @@ async def test_context_manager(environ): assert transport.__aexit__.call_count == 1 +@pytest.mark.asyncio +async def test_close_incomplete_configuration(): + await ManagedIdentityCredential().close() + + +@pytest.mark.asyncio +async def test_context_manager_incomplete_configuration(): + async with ManagedIdentityCredential(): + pass + + @pytest.mark.asyncio async def test_cloud_shell(): """Cloud Shell environment: only MSI_ENDPOINT set""" diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 4cd3f8b1e6d4..5081825cebb5 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -25,9 +25,9 @@ from six.moves.urllib_parse import urlparse try: - from unittest.mock import Mock, patch + from unittest.mock import MagicMock, Mock, patch except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore + from mock import MagicMock, Mock, patch # type: ignore from helpers import ( build_aad_response, @@ -41,6 +41,37 @@ ) +def test_close(): + transport = MagicMock() + credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) + with pytest.raises(CredentialUnavailableError): + credential.get_token('scope') + + assert not transport.__enter__.called + assert not transport.__exit__.called + + credential.close() + assert not transport.__enter__.called + assert transport.__exit__.call_count == 1 + + +def test_context_manager(): + transport = MagicMock() + credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) + with pytest.raises(CredentialUnavailableError): + credential.get_token('scope') + + assert not transport.__enter__.called + assert not transport.__exit__.called + + with credential: + assert transport.__enter__.call_count == 1 + assert not transport.__exit__.called + + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + def test_tenant_id_validation(): """The credential should raise ValueError when given an invalid tenant_id""" From 9aa54a2eb732536516e9db17ae4af98eac53184a Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 28 Jul 2021 13:56:14 -0700 Subject: [PATCH 08/14] tweak test module --- .../tests/test_context_manager.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py index 295f657a6bcc..eab3ade98bcf 100644 --- a/sdk/identity/azure-identity/tests/test_context_manager.py +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -30,13 +30,12 @@ class CredentialFixture: - def __init__(self, cls, default_kwargs=None, ctor_patch=None, get_token_patch=None): + def __init__(self, cls, default_kwargs=None, ctor_patch=None): self.cls = cls - self.get_token_patch = get_token_patch or MagicMock() self._default_kwargs = default_kwargs or {} self._ctor_patch = ctor_patch or MagicMock() - def __call__(self, **kwargs): + def get_credential(self, **kwargs): with self._ctor_patch: return self.cls(**dict(self._default_kwargs, **kwargs)) @@ -51,20 +50,21 @@ def __call__(self, **kwargs): CredentialFixture(DeviceCodeCredential), CredentialFixture( EnvironmentCredential, - ctor_patch=patch.dict("os.environ", {var: ".." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True), + ctor_patch=patch.dict( + EnvironmentCredential.__module__ + ".os.environ", + {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, + ), ), CredentialFixture(InteractiveBrowserCredential), CredentialFixture(UsernamePasswordCredential, {"client_id": "...", "username": "...", "password": "..."}), CredentialFixture(VisualStudioCodeCredential, ctor_patch=patch(GET_USER_SETTINGS, lambda: {})), ) -all_fixtures = pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) - -@all_fixtures +@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) def test_close(fixture): transport = MagicMock() - credential = fixture(transport=transport) + credential = fixture.get_credential(transport=transport) assert not transport.__enter__.called assert not transport.__exit__.called @@ -73,10 +73,10 @@ def test_close(fixture): assert transport.__exit__.call_count == 1 -@all_fixtures +@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) def test_context_manager(fixture): transport = MagicMock() - credential = fixture(transport=transport) + credential = fixture.get_credential(transport=transport) with credential: assert transport.__enter__.call_count == 1 @@ -86,17 +86,24 @@ def test_context_manager(fixture): assert transport.__exit__.call_count == 1 -@all_fixtures +@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) def test_exit_args(fixture): transport = MagicMock() - credential = fixture(transport=transport) + credential = fixture.get_credential(transport=transport) expected_args = ("type", "value", "traceback") credential.__exit__(*expected_args) transport.__exit__.assert_called_once_with(*expected_args) @pytest.mark.parametrize( - "cls", (AzureCliCredential, AzureApplicationCredential, AzurePowerShellCredential, EnvironmentCredential, SharedTokenCacheCredential) + "cls", + ( + AzureCliCredential, + AzureApplicationCredential, + AzurePowerShellCredential, + EnvironmentCredential, + SharedTokenCacheCredential, + ), ) def test_no_op(cls): """Credentials that don't allow custom transports, or require initialization or optional config, should have no-op methods""" From 71c8f1998a23782e62cb5c4e9c9a65ca6da8b8a3 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 28 Jul 2021 16:02:56 -0700 Subject: [PATCH 09/14] work around apparent statefulness of mock.patch.dict on 3.6 --- .../tests/test_context_manager.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py index eab3ade98bcf..e5f6e59a8013 100644 --- a/sdk/identity/azure-identity/tests/test_context_manager.py +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -30,13 +30,14 @@ class CredentialFixture: - def __init__(self, cls, default_kwargs=None, ctor_patch=None): + def __init__(self, cls, default_kwargs=None, ctor_patch_factory=None): self.cls = cls self._default_kwargs = default_kwargs or {} - self._ctor_patch = ctor_patch or MagicMock() + self._ctor_patch_factory = ctor_patch_factory or MagicMock def get_credential(self, **kwargs): - with self._ctor_patch: + patch = self._ctor_patch_factory() + with patch: return self.cls(**dict(self._default_kwargs, **kwargs)) @@ -50,18 +51,20 @@ def get_credential(self, **kwargs): CredentialFixture(DeviceCodeCredential), CredentialFixture( EnvironmentCredential, - ctor_patch=patch.dict( + ctor_patch_factory=lambda: patch.dict( EnvironmentCredential.__module__ + ".os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, ), ), CredentialFixture(InteractiveBrowserCredential), CredentialFixture(UsernamePasswordCredential, {"client_id": "...", "username": "...", "password": "..."}), - CredentialFixture(VisualStudioCodeCredential, ctor_patch=patch(GET_USER_SETTINGS, lambda: {})), + CredentialFixture(VisualStudioCodeCredential, ctor_patch_factory=lambda: patch(GET_USER_SETTINGS, lambda: {})), ) +all_fixtures = pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) -@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) + +@all_fixtures def test_close(fixture): transport = MagicMock() credential = fixture.get_credential(transport=transport) @@ -73,7 +76,7 @@ def test_close(fixture): assert transport.__exit__.call_count == 1 -@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) +@all_fixtures def test_context_manager(fixture): transport = MagicMock() credential = fixture.get_credential(transport=transport) @@ -86,7 +89,7 @@ def test_context_manager(fixture): assert transport.__exit__.call_count == 1 -@pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) +@all_fixtures def test_exit_args(fixture): transport = MagicMock() credential = fixture.get_credential(transport=transport) From a57278669c3950a22f237ee124d78e97c625ac45 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 9 Aug 2021 09:37:36 -0700 Subject: [PATCH 10/14] TokenExchange/ClientAssertionCredential --- .../azure/identity/_credentials/client_assertion.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py index 013307c3e39b..8c3dfefcd8a0 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -34,6 +34,17 @@ def __init__(self, tenant_id, client_id, get_assertion, **kwargs): self._client = AadClient(tenant_id, client_id, **kwargs) super(ClientAssertionCredential, self).__init__(**kwargs) + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def _acquire_token_silently(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_access_token(scopes, **kwargs) From 8384a980a8fdff570644cf937a08ae20958098f0 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 10 Aug 2021 15:14:07 -0700 Subject: [PATCH 11/14] merge changelog --- sdk/identity/azure-identity/CHANGELOG.md | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 20841172a655..d356c259187d 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,9 +1,22 @@ # Release History -## 1.7.0b3 (Unreleased) +## 1.7.0b4 (Unreleased) ### Features Added +### Breaking Changes + +### Bugs Fixed + +### Other Changes +- Added context manager methods and `close()` to credentials in the + `azure.identity` namespace. At the end of a `with` block, or when `close()` + is called, these credentials close their underlying transport sessions. + ([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798)) + + +## 1.7.0b3 (2021-08-10) + ### Breaking Changes > These changes do not impact the API of stable versions such as 1.6.0. > Only code written against a beta version such as 1.7.0b1 may be affected. @@ -29,10 +42,6 @@ workarounds when importing transitive dependencies such as pywin32 fails ([#19989](https://github.com/Azure/azure-sdk-for-python/issues/19989)) -- Added context manager methods and `close()` to credentials in the - `azure.identity` namespace. At the end of a `with` block, or when `close()` - is called, these credentials close their underlying transport sessions. - ([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798)) ## 1.7.0b2 (2021-07-08) From 2672f4faed4aac71f1e8390334d98ab95c68a37a Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 11 Aug 2021 08:29:55 -0700 Subject: [PATCH 12/14] remove redundant tests --- .../azure-identity/tests/test_context_manager.py | 4 +++- .../azure-identity/tests/test_environment_credential.py | 9 --------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py index e5f6e59a8013..a777542a414a 100644 --- a/sdk/identity/azure-identity/tests/test_context_manager.py +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -110,7 +110,9 @@ def test_exit_args(fixture): ) def test_no_op(cls): """Credentials that don't allow custom transports, or require initialization or optional config, should have no-op methods""" - credential = cls() + with patch.dict("os.environ", {}, clear=True): + credential = cls() + with credential: pass credential.close() diff --git a/sdk/identity/azure-identity/tests/test_environment_credential.py b/sdk/identity/azure-identity/tests/test_environment_credential.py index 19dc13484342..2684263973c4 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential.py @@ -20,15 +20,6 @@ } -def test_close_incomplete_configuration(): - EnvironmentCredential().close() - - -def test_context_manager_incomplete_configuration(): - with EnvironmentCredential(): - pass - - def test_incomplete_configuration(): """get_token should raise CredentialUnavailableError for incomplete configuration.""" From 8dda0637a12eba00be72b1e7b5f360973d942ad7 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 11 Aug 2021 08:40:50 -0700 Subject: [PATCH 13/14] narrower patch scope in tests --- .../test_environment_credential_async.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_environment_credential_async.py b/sdk/identity/azure-identity/tests/test_environment_credential_async.py index 43fd568b299f..7b473841abb6 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential_async.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ import itertools -import os from unittest import mock from azure.identity import CredentialUnavailableError @@ -15,11 +14,13 @@ from helpers_async import async_validating_transport, AsyncMockTransport from test_environment_credential import ALL_VARIABLES +ENVIRON = EnvironmentCredential.__module__ + ".os.environ" + @pytest.mark.asyncio async def test_close(): transport = AsyncMockTransport() - with mock.patch.dict("os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): + with mock.patch.dict(ENVIRON, {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): credential = EnvironmentCredential(transport=transport) assert transport.__aexit__.call_count == 0 @@ -30,7 +31,7 @@ async def test_close(): @pytest.mark.asyncio async def test_context_manager(): transport = AsyncMockTransport() - with mock.patch.dict("os.environ", {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): + with mock.patch.dict(ENVIRON, {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): credential = EnvironmentCredential(transport=transport) async with credential: @@ -43,25 +44,27 @@ async def test_context_manager(): @pytest.mark.asyncio async def test_close_incomplete_configuration(): - await EnvironmentCredential().close() + with mock.patch.dict(ENVIRON, {}, clear=True): + await EnvironmentCredential().close() @pytest.mark.asyncio async def test_context_manager_incomplete_configuration(): - async with EnvironmentCredential(): - pass + with mock.patch.dict(ENVIRON, {}, clear=True): + async with EnvironmentCredential(): + pass @pytest.mark.asyncio async def test_incomplete_configuration(): """get_token should raise CredentialUnavailableError for incomplete configuration.""" - with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch.dict(ENVIRON, {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: await EnvironmentCredential().get_token("scope") for a, b in itertools.combinations(ALL_VARIABLES, 2): # all credentials require at least 3 variables set - with mock.patch.dict(os.environ, {a: "a", b: "b"}, clear=True): + with mock.patch.dict(ENVIRON, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: await EnvironmentCredential().get_token("scope") @@ -78,7 +81,7 @@ def test_passes_authority_argument(credential_name, environment_variables): authority = "authority" - with mock.patch.dict("os.environ", {variable: "foo" for variable in environment_variables}, clear=True): + with mock.patch.dict(ENVIRON, {variable: "foo" for variable in environment_variables}, clear=True): with mock.patch(EnvironmentCredential.__module__ + "." + credential_name) as mock_credential: EnvironmentCredential(authority=authority) @@ -101,7 +104,7 @@ def test_client_secret_configuration(): EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } with mock.patch(EnvironmentCredential.__module__ + ".ClientSecretCredential") as mock_credential: - with mock.patch.dict("os.environ", environment, clear=True): + with mock.patch.dict(ENVIRON, environment, clear=True): EnvironmentCredential(foo=bar) assert mock_credential.call_count == 1 @@ -126,7 +129,7 @@ def test_certificate_configuration(): EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } with mock.patch(EnvironmentCredential.__module__ + ".CertificateCredential") as mock_credential: - with mock.patch.dict("os.environ", environment, clear=True): + with mock.patch.dict(ENVIRON, environment, clear=True): EnvironmentCredential(foo=bar) assert mock_credential.call_count == 1 @@ -163,7 +166,7 @@ async def test_client_secret_environment_credential(): EnvironmentVariables.AZURE_CLIENT_SECRET: secret, EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } - with mock.patch.dict("os.environ", environment, clear=True): + with mock.patch.dict(ENVIRON, environment, clear=True): token = await EnvironmentCredential(transport=transport).get_token("scope") assert token.token == access_token From 2f9b065806507d3f17d8a60a9fd65ae978cab984 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 11 Aug 2021 12:12:01 -0700 Subject: [PATCH 14/14] Thanks, McCoy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: McCoy Patiño <39780829+mccoyp@users.noreply.github.com> --- .../azure-identity/azure/identity/aio/_credentials/azure_arc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index babc5fb28f63..f17f3b5b5177 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -32,7 +32,7 @@ def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": return None def get_unavailable_message(self) -> str: - return "Service Fabric managed identity configuration not found in environment" + return "Azure Arc managed identity configuration not found in environment" class ArcChallengeAuthPolicy(AsyncHTTPPolicy):