diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 0bbd7c8b75cd..d356c259187d 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -9,6 +9,11 @@ ### Bugs Fixed ### Other Changes +- Added context manager methods and `close()` to credentials in the + `azure.identity` namespace. At the end of a `with` block, or when `close()` + is called, these credentials close their underlying transport sessions. + ([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798)) + ## 1.7.0b3 (2021-08-10) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index 4c2663cd5d88..7da9589097d3 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -6,45 +6,27 @@ import os from typing import TYPE_CHECKING -from azure.core.credentials import AccessToken from azure.core.pipeline.transport import HttpRequest -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables +from .._internal.managed_identity_base import ManagedIdentityBase from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin if TYPE_CHECKING: from typing import Any, Optional -class AppServiceCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(AppServiceCredential, self).__init__() - +class AppServiceCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = ManagedIdentityClient(**client_args) - else: - self._available = False - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="App Service managed identity configuration not found in environment" - ) - return super(AppServiceCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) - - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + return ManagedIdentityClient(**client_args) + return None + + def get_unavailable_message(self): + # type: () -> str + return "App Service managed identity configuration not found in environment" def _get_client_args(**kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index ce7b33f117f2..587547640744 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -44,6 +44,18 @@ def __init__(self, tenant_id, client_id, authorization_code, redirect_uri, **kwa self._redirect_uri = redirect_uri super(AuthorizationCodeCredential, self).__init__() + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for `scopes`. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index 5fcc5ddfbe9a..ff38ddbc7a83 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -10,51 +10,42 @@ from azure.core.pipeline.transport import HttpRequest from azure.core.pipeline.policies import HTTPPolicy -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables +from .._internal.managed_identity_base import ManagedIdentityBase from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Optional, Union - from azure.core.credentials import AccessToken + from typing import Any, Optional from azure.core.pipeline import PipelineRequest, PipelineResponse - from azure.core.pipeline.policies import SansIOHTTPPolicy - PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy] - - -class AzureArcCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(AzureArcCredential, self).__init__() +class AzureArcCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT) - self._available = url and imds - if self._available: - self._client = ManagedIdentityClient( + if url and imds: + return ManagedIdentityClient( _per_retry_policies=[ArcChallengeAuthPolicy()], request_factory=functools.partial(_get_request, url), **kwargs ) + return None - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="Azure Arc managed identity configuration not found in environment" - ) - return super(AzureArcCredential, self).get_token(*scopes, **kwargs) + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + def close(self): + self.__exit__() - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "Azure Arc managed identity configuration not found in environment" def _get_request(url, scope, identity_config): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py index 111856a3a090..24e160d768c4 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py @@ -44,6 +44,16 @@ class AzureCliCredential(object): def __init__(self, **kwargs): self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def close(self): + # type: () -> None + """Calling this method is unnecessary.""" + @log_get_token("AzureCliCredential") def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py index ed3befb14462..8ce657c6e261 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py @@ -61,6 +61,16 @@ def __init__(self, **kwargs): # type: (**Any) -> None self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def close(self): + # type: () -> None + """Calling this method is unnecessary.""" + @log_get_token("AzurePowerShellCredential") def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index 239adf55d988..35936acb7679 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -53,6 +53,20 @@ def __init__(self, *credentials): self._successful_credential = None # type: Optional[TokenCredential] self.credentials = credentials + def __enter__(self): + for credential in self.credentials: + credential.__enter__() + return self + + def __exit__(self, *args): + for credential in self.credentials: + credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the transport session of each credential in the chain.""" + self.__exit__() + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken """Request a token from each chained credential, in order, returning the first token received. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py index 013307c3e39b..8c3dfefcd8a0 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -34,6 +34,17 @@ def __init__(self, tenant_id, client_id, get_assertion, **kwargs): self._client = AadClient(tenant_id, client_id, **kwargs) super(ClientAssertionCredential, self).__init__(**kwargs) + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def _acquire_token_silently(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_access_token(scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py index 17e10feec6d9..a9fb5ed96432 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py @@ -8,46 +8,27 @@ from azure.core.pipeline.transport import HttpRequest -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin +from .._internal.managed_identity_base import ManagedIdentityBase if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class CloudShellCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(CloudShellCredential, self).__init__() +class CloudShellCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: - self._available = True - self._client = ManagedIdentityClient( - request_factory=functools.partial(_get_request, url), - base_headers={"Metadata": "true"}, - **kwargs + return ManagedIdentityClient( + request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs ) - else: - self._available = False - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="Cloud Shell managed identity configuration not found in environment" - ) - return super(CloudShellCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + return None - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "Cloud Shell managed identity configuration not found in environment" def _get_request(url, scope, identity_config): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py index 97aecf8c03f1..51e77d160372 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py @@ -101,6 +101,20 @@ def __init__(self, **kwargs): else: _LOGGER.info("No environment configuration found.") + def __enter__(self): + if self._credential: + self._credential.__enter__() + return self + + def __exit__(self, *args): + if self._credential: + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + @log_get_token("EnvironmentCredential") def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index ca385eeecb7e..d642dd063b21 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -56,6 +56,16 @@ def __init__(self, **kwargs): self._error_message = None # type: Optional[str] self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + self.__exit__() + def _acquire_token_silently(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index d0a0acef7931..3a291c1e1df3 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -81,6 +81,18 @@ def __init__(self, **kwargs): _LOGGER.info("%s will use IMDS", self.__class__.__name__) self._credential = ImdsCredential(**kwargs) + def __enter__(self): + self._credential.__enter__() + return self + + def __exit__(self, *args): + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + @log_get_token("ManagedIdentityCredential") def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py index 0594a06f84ac..cb8c29a84ccc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py @@ -8,43 +8,25 @@ from azure.core.pipeline.transport import HttpRequest -from .. import CredentialUnavailableError from .._constants import EnvironmentVariables +from .._internal.managed_identity_base import ManagedIdentityBase from .._internal.managed_identity_client import ManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class ServiceFabricCredential(GetTokenMixin): - def __init__(self, **kwargs): - # type: (**Any) -> None - super(ServiceFabricCredential, self).__init__() - +class ServiceFabricCredential(ManagedIdentityBase): + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = ManagedIdentityClient(**client_args) - else: - self._available = False - - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - if not self._available: - raise CredentialUnavailableError( - message="Service Fabric managed identity configuration not found in environment" - ) - return super(ServiceFabricCredential, self).get_token(*scopes, **kwargs) - - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + return ManagedIdentityClient(**client_args) + return None - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - return self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self): + # type: () -> str + return "Service Fabric managed identity configuration not found in environment" def _get_client_args(**kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 7d49c9dcac3f..906aaad174fc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -48,6 +48,18 @@ def __init__(self, username=None, **kwargs): else: self._credential = _SharedTokenCacheCredential(username=username, **kwargs) + def __enter__(self): + self._credential.__enter__() + return self + + def __exit__(self, *args): + self._credential.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + @log_get_token("SharedTokenCacheCredential") def get_token(self, *scopes, **kwargs): # type (*str, **Any) -> AccessToken @@ -84,6 +96,15 @@ def supported(): class _SharedTokenCacheCredential(SharedTokenCacheBase): """The original SharedTokenCacheCredential, which doesn't use msal.ClientApplication""" + def __enter__(self): + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args): + if self._client: + self._client.__exit__(*args) + def get_token(self, *scopes, **kwargs): # type (*str, **Any) -> AccessToken if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index e6996085c199..b1aa1ade8c46 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -41,6 +41,13 @@ def __init__(self, authentication_record, **kwargs): self._client = MsalClient(**kwargs) self._initialized = False + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type (*str, **Any) -> AccessToken if not scopes: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index 904766782bf7..cd6866f319da 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -125,6 +125,20 @@ class VisualStudioCodeCredential(_VSCodeCredentialBase, GetTokenMixin): user's home tenant or the tenant configured by **tenant_id** or VS Code's user settings. """ + def __enter__(self): + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args): + if self._client: + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close the credential's transport session.""" + self.__exit__() + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index 79f986fca8a5..fb27461d8eb4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -17,6 +17,17 @@ class AadClient(AadClientBase): + def __enter__(self): + self._pipeline.__enter__() + return self + + def __exit__(self, *args): + self._pipeline.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): # type: (Iterable[str], str, str, Optional[str], **Any) -> AccessToken request = self._get_auth_code_request( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py new file mode 100644 index 000000000000..fba0c90e4071 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -0,0 +1,62 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc +from typing import cast, TYPE_CHECKING + +from .. import CredentialUnavailableError +from .._internal.managed_identity_client import ManagedIdentityClient +from .._internal.get_token_mixin import GetTokenMixin + +if TYPE_CHECKING: + from typing import Any, Optional + from azure.core.credentials import AccessToken + + +class ManagedIdentityBase(GetTokenMixin): + """Base class for internal credentials using ManagedIdentityClient""" + + def __init__(self, **kwargs): + # type: (**Any) -> None + super(ManagedIdentityBase, self).__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs): + # type: (**Any) -> Optional[ManagedIdentityClient] + pass + + @abc.abstractmethod + def get_unavailable_message(self): + # type: () -> str + pass + + def __enter__(self): + if self._client: + self._client.__enter__() + return self + + def __exit__(self, *args): + if self._client: + self._client.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + + def get_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return super(ManagedIdentityBase, self).get_token(*scopes, **kwargs) + + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) + + def _request_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index b593f7394076..2b2164f9a382 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -104,6 +104,17 @@ def _build_pipeline(self, **kwargs): class ManagedIdentityClient(ManagedIdentityClientBase): + def __enter__(self): + self._pipeline.__enter__() + return self + + def __exit__(self, *args): + self._pipeline.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken resource = _scopes_to_resource(*scopes) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py index 87fdcde5a0e4..faec84d439d7 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py @@ -73,6 +73,17 @@ def __init__(self, **kwargs): # pylint:disable=missing-client-constructor-param self._local = threading.local() self._pipeline = build_pipeline(**kwargs) + def __enter__(self): + self._pipeline.__enter__() + return self + + def __exit__(self, *args): + self._pipeline.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:disable=unused-argument # type: (str, Optional[Dict[str, str]], RequestData, Optional[Dict[str, str]], **Any) -> MsalResponse request = HttpRequest("POST", url, headers=headers) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 3732f1595313..8ac10bbd687d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -50,6 +50,17 @@ def __init__(self, client_id, client_credential=None, **kwargs): super(MsalCredential, self).__init__() + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, *args): + self._client.__exit__(*args) + + def close(self): + # type: () -> None + self.__exit__() + def _get_app(self, **kwargs): # type: (**Any) -> msal.ClientApplication tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py index e09f35006c1e..b41274042b11 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py @@ -4,46 +4,20 @@ # ------------------------------------ from typing import TYPE_CHECKING -from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin -from ... import CredentialUnavailableError from ..._credentials.app_service import _get_client_args if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class AppServiceCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super(AppServiceCredential, self).__init__() - +class AppServiceCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = AsyncManagedIdentityClient(**client_args) - else: - self._available = False - - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="App Service managed identity configuration not found in environment" - ) - - return await super().get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + return AsyncManagedIdentityClient(**client_args) + return None - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "App Service managed identity configuration not found in environment" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index 0b75b63af44a..f17f3b5b5177 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -8,55 +8,31 @@ from azure.core.pipeline.policies import AsyncHTTPPolicy -from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin -from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from ..._credentials.azure_arc import _get_request, _get_secret_key if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any, Optional - from azure.core.credentials import AccessToken from azure.core.pipeline import PipelineRequest, PipelineResponse -class AzureArcCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super().__init__() - +class AzureArcCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT) - self._available = url and imds - if self._available: - self._client = AsyncManagedIdentityClient( + if url and imds: + return AsyncManagedIdentityClient( _per_retry_policies=[ArcChallengeAuthPolicy()], request_factory=functools.partial(_get_request, url), **kwargs ) + return None - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="Service Fabric managed identity configuration not found in environment" - ) - - return await super().get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) - - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "Azure Arc managed identity configuration not found in environment" class ArcChallengeAuthPolicy(AsyncHTTPPolicy): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py index 512fc621a21b..140d9afe8bbf 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py @@ -6,49 +6,23 @@ import os from typing import TYPE_CHECKING -from .._internal import AsyncContextManager -from .._internal.get_token_mixin import GetTokenMixin +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from ..._credentials.cloud_shell import _get_request if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class CloudShellCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super(CloudShellCredential, self).__init__() +class CloudShellCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: - self._available = True - self._client = AsyncManagedIdentityClient( - request_factory=functools.partial(_get_request, url), - base_headers={"Metadata": "true"}, - **kwargs, + return AsyncManagedIdentityClient( + request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs ) - else: - self._available = False + return None - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="Cloud Shell managed identity configuration not found in environment" - ) - return await super(CloudShellCredential, self).get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) - - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "Cloud Shell managed identity configuration not found in environment" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py index 5e8de07d763e..1ab3477bab90 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py @@ -4,46 +4,20 @@ # ------------------------------------ from typing import TYPE_CHECKING -from .._internal import AsyncContextManager +from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient -from .._internal.get_token_mixin import GetTokenMixin -from ... import CredentialUnavailableError from ..._credentials.service_fabric import _get_client_args if TYPE_CHECKING: from typing import Any, Optional - from azure.core.credentials import AccessToken -class ServiceFabricCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs: "Any") -> None: - super().__init__() - +class ServiceFabricCredential(AsyncManagedIdentityBase): + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": client_args = _get_client_args(**kwargs) if client_args: - self._available = True - self._client = AsyncManagedIdentityClient(**client_args) - else: - self._available = False - - async def __aenter__(self): - if self._available: - await self._client.__aenter__() - return self - - async def close(self) -> None: - await self._client.close() - - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._available: - raise CredentialUnavailableError( - message="Service Fabric managed identity configuration not found in environment" - ) - - return await super().get_token(*scopes, **kwargs) - - async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + return AsyncManagedIdentityClient(**client_args) + return None - async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - return await self._client.request_token(*scopes, **kwargs) + def get_unavailable_message(self) -> str: + return "Service Fabric managed identity configuration not found in environment" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py new file mode 100644 index 000000000000..63d8b6db9b49 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -0,0 +1,56 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc +from typing import cast, TYPE_CHECKING + +from . import AsyncContextManager +from .get_token_mixin import GetTokenMixin +from .managed_identity_client import AsyncManagedIdentityClient +from ... import CredentialUnavailableError + +if TYPE_CHECKING: + from typing import Any, Optional + from azure.core.credentials import AccessToken + + +class AsyncManagedIdentityBase(AsyncContextManager, GetTokenMixin): + """Base class for internal credentials using AsyncManagedIdentityClient""" + + def __init__(self, **kwargs: "Any") -> None: + super().__init__() + self._client = self.get_client(**kwargs) + + @abc.abstractmethod + def get_client(self, **kwargs: "Any") -> "Optional[AsyncManagedIdentityClient]": + pass + + @abc.abstractmethod + def get_unavailable_message(self) -> str: + pass + + async def __aenter__(self): + if self._client: + await self._client.__aenter__() + return self + + async def __aexit__(self, *args): + if self._client: + await self._client.__aexit__(*args) + + async def close(self) -> None: + await self.__aexit__() + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return await super().get_token(*scopes, **kwargs) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + # casting because mypy can't determine that these methods are called + # only by get_token, which raises when self._client is None + return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) + + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/tests/test_chained_credential.py b/sdk/identity/azure-identity/tests/test_chained_credential.py index 85ce50b5035a..db6da5dc80c7 100644 --- a/sdk/identity/azure-identity/tests/test_chained_credential.py +++ b/sdk/identity/azure-identity/tests/test_chained_credential.py @@ -3,9 +3,9 @@ # Licensed under the MIT License. # ------------------------------------ try: - from unittest.mock import Mock + from unittest.mock import MagicMock, Mock except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import MagicMock, Mock # type: ignore from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -17,6 +17,37 @@ import pytest +def test_close(): + credentials = [MagicMock(close=Mock()) for _ in range(5)] + chain = ChainedTokenCredential(*credentials) + + for credential in credentials: + assert credential.__exit__.call_count == 0 + + chain.close() + + for credential in credentials: + assert credential.__exit__.call_count == 1 + + +def test_context_manager(): + credentials = [MagicMock() for _ in range(5)] + chain = ChainedTokenCredential(*credentials) + + for credential in credentials: + assert credential.__enter__.call_count == 0 + assert credential.__exit__.call_count == 0 + + with chain: + for credential in credentials: + assert credential.__enter__.call_count == 1 + assert credential.__exit__.call_count == 0 + + for credential in credentials: + assert credential.__enter__.call_count == 1 + assert credential.__exit__.call_count == 1 + + def test_error_message(): first_error = "first_error" first_credential = Mock( diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py new file mode 100644 index 000000000000..a777542a414a --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -0,0 +1,118 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +try: + from unittest.mock import MagicMock, patch +except ImportError: + from mock import MagicMock, patch # type: ignore + +from azure.identity import ( + AzureApplicationCredential, + AzureCliCredential, + AzurePowerShellCredential, + AuthorizationCodeCredential, + CertificateCredential, + ClientSecretCredential, + DeviceCodeCredential, + EnvironmentCredential, + InteractiveBrowserCredential, + SharedTokenCacheCredential, + UsernamePasswordCredential, + VisualStudioCodeCredential, +) +from azure.identity._constants import EnvironmentVariables + +import pytest + +from test_certificate_credential import CERT_PATH +from test_vscode_credential import GET_USER_SETTINGS + + +class CredentialFixture: + def __init__(self, cls, default_kwargs=None, ctor_patch_factory=None): + self.cls = cls + self._default_kwargs = default_kwargs or {} + self._ctor_patch_factory = ctor_patch_factory or MagicMock + + def get_credential(self, **kwargs): + patch = self._ctor_patch_factory() + with patch: + return self.cls(**dict(self._default_kwargs, **kwargs)) + + +FIXTURES = ( + CredentialFixture( + AuthorizationCodeCredential, + {kwarg: "..." for kwarg in ("tenant_id", "client_id", "authorization_code", "redirect_uri")}, + ), + CredentialFixture(CertificateCredential, {"tenant_id": "...", "client_id": "...", "certificate_path": CERT_PATH}), + CredentialFixture(ClientSecretCredential, {kwarg: "..." for kwarg in ("tenant_id", "client_id", "client_secret")}), + CredentialFixture(DeviceCodeCredential), + CredentialFixture( + EnvironmentCredential, + ctor_patch_factory=lambda: patch.dict( + EnvironmentCredential.__module__ + ".os.environ", + {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, + ), + ), + CredentialFixture(InteractiveBrowserCredential), + CredentialFixture(UsernamePasswordCredential, {"client_id": "...", "username": "...", "password": "..."}), + CredentialFixture(VisualStudioCodeCredential, ctor_patch_factory=lambda: patch(GET_USER_SETTINGS, lambda: {})), +) + +all_fixtures = pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__) + + +@all_fixtures +def test_close(fixture): + transport = MagicMock() + credential = fixture.get_credential(transport=transport) + assert not transport.__enter__.called + assert not transport.__exit__.called + + credential.close() + assert not transport.__enter__.called + assert transport.__exit__.call_count == 1 + + +@all_fixtures +def test_context_manager(fixture): + transport = MagicMock() + credential = fixture.get_credential(transport=transport) + + with credential: + assert transport.__enter__.call_count == 1 + assert not transport.__exit__.called + + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + +@all_fixtures +def test_exit_args(fixture): + transport = MagicMock() + credential = fixture.get_credential(transport=transport) + expected_args = ("type", "value", "traceback") + credential.__exit__(*expected_args) + transport.__exit__.assert_called_once_with(*expected_args) + + +@pytest.mark.parametrize( + "cls", + ( + AzureCliCredential, + AzureApplicationCredential, + AzurePowerShellCredential, + EnvironmentCredential, + SharedTokenCacheCredential, + ), +) +def test_no_op(cls): + """Credentials that don't allow custom transports, or require initialization or optional config, should have no-op methods""" + with patch.dict("os.environ", {}, clear=True): + credential = cls() + + with credential: + pass + credential.close() diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 31c2ba1a71b4..a03a16cb85cc 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -24,9 +24,32 @@ from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache try: - from unittest.mock import Mock, patch + from unittest.mock import MagicMock, Mock, patch except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore + from mock import MagicMock, Mock, patch # type: ignore + + +def test_close(): + transport = MagicMock() + credential = DefaultAzureCredential(transport=transport) + assert not transport.__enter__.called + assert not transport.__exit__.called + + credential.close() + assert not transport.__enter__.called + assert transport.__exit__.called # call count depends on the chain's composition + + +def test_context_manager(): + transport = MagicMock() + credential = DefaultAzureCredential(transport=transport) + + with credential: + assert transport.__enter__.called # call count depends on the chain's composition + assert not transport.__exit__.called + + assert transport.__enter__.called + assert transport.__exit__.called def test_iterates_only_once(): diff --git a/sdk/identity/azure-identity/tests/test_environment_credential.py b/sdk/identity/azure-identity/tests/test_environment_credential.py index 35ce51045ef7..2684263973c4 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential.py @@ -9,7 +9,7 @@ from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock, mock_response, Request, validating_transport +from helpers import mock ALL_VARIABLES = { diff --git a/sdk/identity/azure-identity/tests/test_environment_credential_async.py b/sdk/identity/azure-identity/tests/test_environment_credential_async.py index 634dfcb8a2dd..7b473841abb6 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential_async.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ import itertools -import os from unittest import mock from azure.identity import CredentialUnavailableError @@ -12,20 +11,60 @@ import pytest from helpers import mock_response, Request -from helpers_async import async_validating_transport +from helpers_async import async_validating_transport, AsyncMockTransport from test_environment_credential import ALL_VARIABLES +ENVIRON = EnvironmentCredential.__module__ + ".os.environ" + + +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + with mock.patch.dict(ENVIRON, {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): + credential = EnvironmentCredential(transport=transport) + assert transport.__aexit__.call_count == 0 + + await credential.close() + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + with mock.patch.dict(ENVIRON, {var: "..." for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): + credential = EnvironmentCredential(transport=transport) + + async with credential: + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 0 + + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_close_incomplete_configuration(): + with mock.patch.dict(ENVIRON, {}, clear=True): + await EnvironmentCredential().close() + + +@pytest.mark.asyncio +async def test_context_manager_incomplete_configuration(): + with mock.patch.dict(ENVIRON, {}, clear=True): + async with EnvironmentCredential(): + pass + @pytest.mark.asyncio async def test_incomplete_configuration(): """get_token should raise CredentialUnavailableError for incomplete configuration.""" - with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch.dict(ENVIRON, {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: await EnvironmentCredential().get_token("scope") for a, b in itertools.combinations(ALL_VARIABLES, 2): # all credentials require at least 3 variables set - with mock.patch.dict(os.environ, {a: "a", b: "b"}, clear=True): + with mock.patch.dict(ENVIRON, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: await EnvironmentCredential().get_token("scope") @@ -42,7 +81,7 @@ def test_passes_authority_argument(credential_name, environment_variables): authority = "authority" - with mock.patch.dict("os.environ", {variable: "foo" for variable in environment_variables}, clear=True): + with mock.patch.dict(ENVIRON, {variable: "foo" for variable in environment_variables}, clear=True): with mock.patch(EnvironmentCredential.__module__ + "." + credential_name) as mock_credential: EnvironmentCredential(authority=authority) @@ -65,7 +104,7 @@ def test_client_secret_configuration(): EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } with mock.patch(EnvironmentCredential.__module__ + ".ClientSecretCredential") as mock_credential: - with mock.patch.dict("os.environ", environment, clear=True): + with mock.patch.dict(ENVIRON, environment, clear=True): EnvironmentCredential(foo=bar) assert mock_credential.call_count == 1 @@ -90,7 +129,7 @@ def test_certificate_configuration(): EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } with mock.patch(EnvironmentCredential.__module__ + ".CertificateCredential") as mock_credential: - with mock.patch.dict("os.environ", environment, clear=True): + with mock.patch.dict(ENVIRON, environment, clear=True): EnvironmentCredential(foo=bar) assert mock_credential.call_count == 1 @@ -127,7 +166,7 @@ async def test_client_secret_environment_credential(): EnvironmentVariables.AZURE_CLIENT_SECRET: secret, EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } - with mock.patch.dict("os.environ", environment, clear=True): + with mock.patch.dict(ENVIRON, environment, clear=True): token = await EnvironmentCredential(transport=transport).get_token("scope") assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index eb199f08b368..994b25b0d9f0 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -21,6 +21,56 @@ from helpers import build_aad_response, validating_transport, mock_response, Request MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ" +ALL_ENVIRONMENTS = ( + {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # App Service + {EnvironmentVariables.MSI_ENDPOINT: "..."}, # Cloud Shell + { # Service Fabric + EnvironmentVariables.IDENTITY_ENDPOINT: "...", + EnvironmentVariables.IDENTITY_HEADER: "...", + EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...", + }, + {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc + { # token exchange + EnvironmentVariables.AZURE_CLIENT_ID: "...", + EnvironmentVariables.AZURE_TENANT_ID: "...", + EnvironmentVariables.TOKEN_FILE_PATH: __file__, + }, + {}, # IMDS +) + + +@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) +def test_close(environ): + transport = mock.MagicMock() + with mock.patch.dict("os.environ", environ, clear=True): + credential = ManagedIdentityCredential(transport=transport) + assert transport.__exit__.call_count == 0 + + credential.close() + assert transport.__exit__.call_count == 1 + + +@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) +def test_context_manager(environ): + transport = mock.MagicMock() + with mock.patch.dict("os.environ", environ, clear=True): + credential = ManagedIdentityCredential(transport=transport) + + with credential: + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 0 + + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + +def test_close_incomplete_configuration(): + ManagedIdentityCredential().close() + + +def test_context_manager_incomplete_configuration(): + with ManagedIdentityCredential(): + pass ALL_ENVIRONMENTS = ( diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 14f616c404cc..0f2f5088ea44 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -16,7 +16,7 @@ import pytest from helpers import build_aad_response, mock_response, Request -from helpers_async import async_validating_transport, AsyncMockTransport, get_completed_future +from helpers_async import async_validating_transport, AsyncMockTransport from test_managed_identity import ALL_ENVIRONMENTS @@ -94,6 +94,17 @@ async def test_context_manager(environ): assert transport.__aexit__.call_count == 1 +@pytest.mark.asyncio +async def test_close_incomplete_configuration(): + await ManagedIdentityCredential().close() + + +@pytest.mark.asyncio +async def test_context_manager_incomplete_configuration(): + async with ManagedIdentityCredential(): + pass + + @pytest.mark.asyncio async def test_cloud_shell(): """Cloud Shell environment: only MSI_ENDPOINT set""" diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 4cd3f8b1e6d4..5081825cebb5 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -25,9 +25,9 @@ from six.moves.urllib_parse import urlparse try: - from unittest.mock import Mock, patch + from unittest.mock import MagicMock, Mock, patch except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore + from mock import MagicMock, Mock, patch # type: ignore from helpers import ( build_aad_response, @@ -41,6 +41,37 @@ ) +def test_close(): + transport = MagicMock() + credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) + with pytest.raises(CredentialUnavailableError): + credential.get_token('scope') + + assert not transport.__enter__.called + assert not transport.__exit__.called + + credential.close() + assert not transport.__enter__.called + assert transport.__exit__.call_count == 1 + + +def test_context_manager(): + transport = MagicMock() + credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) + with pytest.raises(CredentialUnavailableError): + credential.get_token('scope') + + assert not transport.__enter__.called + assert not transport.__exit__.called + + with credential: + assert transport.__enter__.call_count == 1 + assert not transport.__exit__.called + + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + def test_tenant_id_validation(): """The credential should raise ValueError when given an invalid tenant_id""" diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index ec06e41d14b4..e6db05f56e5f 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -225,6 +225,15 @@ def test_adfs(): assert "adfs" in ex.value.message.lower() +def test_custom_cloud_no_authority(): + """The credential is unavailable when VS Code is configured to use a custom cloud with no known authority""" + + cloud_name = "AzureCustomCloud" + credential = get_credential({"azure.cloud": cloud_name}) + with pytest.raises(CredentialUnavailableError, match="authority.*" + cloud_name): + credential.get_token("scope") + + @pytest.mark.parametrize( "cloud,authority", ( diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index d91332be0320..40c996e39957 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -213,6 +213,16 @@ async def test_adfs(): assert "adfs" in ex.value.message.lower() +@pytest.mark.asyncio +async def test_custom_cloud_no_authority(): + """The credential is unavailable when VS Code is configured to use a cloud with no known authority""" + + cloud_name = "AzureCustomCloud" + credential = get_credential({"azure.cloud": cloud_name}) + with pytest.raises(CredentialUnavailableError, match="authority.*" + cloud_name): + await credential.get_token("scope") + + @pytest.mark.asyncio @pytest.mark.parametrize( "cloud,authority",