diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index ec443a38acb0..8fce2c892eda 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -44,4 +44,5 @@ class EnvironmentVariables: MSI_SECRET = "MSI_SECRET" AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST" + AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION = "AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION" AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index f32dbce5face..4c2663cd5d88 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -38,8 +38,8 @@ def get_token(self, *scopes, **kwargs): ) return super(AppServiceCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) def _request_token(self, *scopes, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/_credentials/application.py index 137ab6fdcbc1..c8412ac82014 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/application.py @@ -48,6 +48,10 @@ class AzureApplicationCredential(ChainedTokenCredential): `_ for an overview of managed identities. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application or user is registered in. When False, which is the default, the credential will acquire tokens + only from the tenant specified by **AZURE_TENANT_ID**. This argument doesn't apply to managed identity + authentication. :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud, which is the default when no value is given for this keyword argument or environment variable AZURE_AUTHORITY_HOST. :class:`~azure.identity.AzureAuthorityHosts` defines authorities for diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index 0017991e70d4..ce7b33f117f2 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -10,25 +10,29 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Iterable, Optional + from typing import Any, Optional from azure.core.credentials import AccessToken class AuthorizationCodeCredential(GetTokenMixin): """Authenticates by redeeming an authorization code previously obtained from Azure Active Directory. - See https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information + See `Azure Active Directory documentation + `_ for more information about the authentication flow. - :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its 'directory' ID. + :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its "directory" ID. :param str client_id: the application's client ID :param str authorization_code: the authorization code from the user's log-in :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. + :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. :keyword str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant specified by **tenant_id**. """ def __init__(self, tenant_id, client_id, authorization_code, redirect_uri, **kwargs): @@ -51,16 +55,20 @@ def get_token(self, *scopes, **kwargs): redeeming the authorization code. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. Any error response from Azure Active Directory is available as the error's ``response`` attribute. """ - return super(AuthorizationCodeCredential, self).get_token(*scopes) + # pylint:disable=useless-super-delegation + return super(AuthorizationCodeCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] - return self._client.get_cached_access_token(scopes) + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + return self._client.get_cached_access_token(scopes, **kwargs) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index a827744917fa..03e07ed59d32 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -58,8 +58,8 @@ def get_token(self, *scopes, **kwargs): ) return super(AzureArcCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) def _request_token(self, *scopes, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py index c8f5cc8bd18a..111856a3a090 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py @@ -18,7 +18,7 @@ from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError -from .._internal import _scopes_to_resource +from .._internal import _scopes_to_resource, resolve_tenant from .._internal.decorators import log_get_token if TYPE_CHECKING: @@ -35,10 +35,17 @@ class AzureCliCredential(object): """Authenticates by requesting a token from the Azure CLI. This requires previously logging in to Azure via "az login", and will use the CLI's currently logged in identity. + + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the identity logged in to the Azure CLI is registered in. When False, which is the default, the credential will + acquire tokens only from the tenant of the Azure CLI's active subscription. """ + def __init__(self, **kwargs): + self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + @log_get_token("AzureCliCredential") - def get_token(self, *scopes, **kwargs): # pylint:disable=no-self-use,unused-argument + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for `scopes`. @@ -46,6 +53,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=no-self-use,unused-arg also handle token caching because this credential doesn't cache the tokens it acquires. :param str scopes: desired scope for the access token. This credential allows only one scope per request. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke the Azure CLI. @@ -54,7 +64,11 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=no-self-use,unused-arg """ resource = _scopes_to_resource(*scopes) - output = _run_command(COMMAND_LINE.format(resource)) + command = COMMAND_LINE.format(resource) + tenant = resolve_tenant("", self._allow_multitenant, **kwargs) + if tenant: + command += " --tenant " + tenant + output = _run_command(command) token = parse_token(output) if not token: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py index ded72457b1e5..ed3befb14462 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py @@ -16,7 +16,7 @@ from .azure_cli import get_safe_working_dir from .. import CredentialUnavailableError -from .._internal import _scopes_to_resource +from .._internal import _scopes_to_resource, resolve_tenant from .._internal.decorators import log_get_token if TYPE_CHECKING: @@ -41,7 +41,7 @@ exit }} -$token = Get-AzAccessToken -ResourceUrl '{}' +$token = Get-AzAccessToken -ResourceUrl '{}'{} Write-Output "`nazsdk%$($token.Token)%$($token.ExpiresOn.ToUnixTimeSeconds())`n" """ @@ -51,10 +51,18 @@ class AzurePowerShellCredential(object): """Authenticates by requesting a token from Azure PowerShell. This requires previously logging in to Azure via "Connect-AzAccount", and will use the currently logged in identity. + + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the identity logged in to Azure PowerShell is registered in. When False, which is the default, the credential + will acquire tokens only from the tenant of Azure PowerShell's active subscription. """ + def __init__(self, **kwargs): + # type: (**Any) -> None + self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + @log_get_token("AzurePowerShellCredential") - def get_token(self, *scopes, **kwargs): # pylint:disable=no-self-use,unused-argument + def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for `scopes`. @@ -62,6 +70,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=no-self-use,unused-arg also handle token caching because this credential doesn't cache the tokens it acquires. :param str scopes: desired scope for the access token. This credential allows only one scope per request. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke Azure PowerShell, or @@ -69,8 +80,8 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=no-self-use,unused-arg :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't receive an access token """ - - command_line = get_command_line(scopes) + tenant_id = resolve_tenant("", self._allow_multitenant, **kwargs) + command_line = get_command_line(scopes, tenant_id) output = run_command_line(command_line) token = parse_token(output) return token @@ -128,10 +139,14 @@ def parse_token(output): raise ClientAuthenticationError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output)) -def get_command_line(scopes): - # type: (Tuple) -> List[str] +def get_command_line(scopes, tenant_id): + # type: (Tuple, str) -> List[str] + if tenant_id: + tenant_argument = " -TenantId " + tenant_id + else: + tenant_argument = "" resource = _scopes_to_resource(*scopes) - script = SCRIPT.format(NO_AZ_ACCOUNT_MODULE, resource) + script = SCRIPT.format(NO_AZ_ACCOUNT_MODULE, resource, tenant_argument) encoded_script = base64.b64encode(script.encode("utf-16-le")).decode() command = "pwsh -NonInteractive -EncodedCommand " + encoded_script diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py index a8e025fb2f3e..6aead5b26f47 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py @@ -31,10 +31,10 @@ class InteractiveBrowserCredential(InteractiveCredential): :func:`~get_token` opens a browser to a login URL provided by Azure Active Directory and authenticates a user there with the authorization code flow, using PKCE (Proof Key for Code Exchange) internally to protect the code. - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', + :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. - :keyword str tenant_id: an Azure Active Directory tenant ID. Defaults to the 'organizations' tenant, which can + :keyword str tenant_id: an Azure Active Directory tenant ID. Defaults to the "organizations" tenant, which can authenticate work or school accounts. :keyword str client_id: Client ID of the Azure Active Directory application users will sign in to. If unspecified, users will authenticate to an Azure development application. @@ -42,7 +42,7 @@ class InteractiveBrowserCredential(InteractiveCredential): may still log in with a different username. :keyword str redirect_uri: a redirect URI for the application identified by `client_id` as configured in Azure Active Directory, for example "http://localhost:8400". This is only required when passing a value for - `client_id`, and must match a redirect URI in the application's registration. The credential must be able to + **client_id**, and must match a redirect URI in the application's registration. The credential must be able to bind a socket to this URI. :keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate` :keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise @@ -51,7 +51,10 @@ class InteractiveBrowserCredential(InteractiveCredential): will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions :keyword int timeout: seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes). - :raises ValueError: invalid `redirect_uri` + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant specified by **tenant_id**. + :raises ValueError: invalid **redirect_uri** """ def __init__(self, **kwargs): @@ -97,7 +100,7 @@ def _request_token(self, *scopes, **kwargs): # get the url the user must visit to authenticate scopes = list(scopes) # type: ignore claims = kwargs.get("claims") - app = self._get_app() + app = self._get_app(**kwargs) flow = app.initiate_auth_code_flow( scopes, redirect_uri=redirect_uri, diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index db4a15c09d0f..0b04a93d4582 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -26,23 +26,26 @@ class CertificateCredential(ClientCredentialBase): `_ for more information on configuring certificate authentication. - :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. + :param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID. :param str client_id: the service principal's client ID :param str certificate_path: path to a PEM-encoded certificate file including the private key. If not provided, - `certificate_data` is required. + **certificate_data** is required. :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. :keyword bytes certificate_data: the bytes of a certificate in PEM format, including the private key :keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate - requires a different encoding, pass appropriately encoded bytes instead. + requires a different encoding, pass appropriately encoded bytes instead. :paramtype password: str or bytes + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application is registered in. When False, which is the default, the credential will acquire tokens only from + the tenant specified by **tenant_id**. :keyword bool send_certificate_chain: if True, the credential will send the public certificate chain in the x5c - header of each token request's JWT. This is required for Subject Name/Issuer (SNI) authentication. Defaults - to False. + header of each token request's JWT. This is required for Subject Name/Issuer (SNI) authentication. Defaults to + False. :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential - will cache tokens in memory. + will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions :keyword ~azure.identity.RegionalAuthority regional_authority: a :class:`~azure.identity.RegionalAuthority` to which the credential will authenticate. This argument should be used only by applications deployed to Azure diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py index 1a46955d8224..9623b0ef8b1d 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py @@ -14,15 +14,18 @@ class ClientSecretCredential(ClientCredentialBase): """Authenticates as a service principal using a client secret. - :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. + :param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID. :param str client_id: the service principal's client ID :param str client_secret: one of the service principal's client secrets - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. + :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application is registered in. When False, which is the default, the credential will acquire tokens only from + the tenant specified by **tenant_id**. :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential - will cache tokens in memory. + will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions :keyword ~azure.identity.RegionalAuthority regional_authority: a :class:`~azure.identity.RegionalAuthority` to which the credential will authenticate. This argument should be used only by applications deployed to Azure diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py index b58b4fa4a3f4..17e10feec6d9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py @@ -41,8 +41,8 @@ def get_token(self, *scopes, **kwargs): ) return super(CloudShellCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) def _request_token(self, *scopes, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index e163c5befa1b..bc74b37ecc89 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -47,9 +47,12 @@ class DefaultAzureCredential(ChainedTokenCredential): This default behavior is configurable with keyword arguments. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application is registered in. When False, which is the default, the credential will acquire tokens only from + its configured tenant. This argument doesn't apply to managed identity authentication. :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud. + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud. :keyword bool exclude_cli_credential: Whether to exclude the Azure CLI from the credential. Defaults to **False**. :keyword bool exclude_environment_credential: Whether to exclude a service principal configured by environment variables from the credential. Defaults to **False**. @@ -84,7 +87,7 @@ def __init__(self, **kwargs): vscode_tenant_id = kwargs.pop( "visual_studio_code_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID) ) - vscode_args = {} + vscode_args = dict(kwargs) if authority: vscode_args["authority"] = authority if vscode_tenant_id: @@ -130,11 +133,11 @@ def __init__(self, **kwargs): if not exclude_visual_studio_code_credential: credentials.append(VisualStudioCodeCredential(**vscode_args)) if not exclude_cli_credential: - credentials.append(AzureCliCredential()) + credentials.append(AzureCliCredential(**kwargs)) if not exclude_powershell_credential: - credentials.append(AzurePowerShellCredential()) + credentials.append(AzurePowerShellCredential(**kwargs)) if not exclude_interactive_browser_credential: - credentials.append(InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id)) + credentials.append(InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id, **kwargs)) super(DefaultAzureCredential, self).__init__(*credentials) @@ -145,6 +148,11 @@ def get_token(self, *scopes, **kwargs): This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a `message` attribute listing each authentication attempt and its error message. """ diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py index dec7303b334c..657bd5eb5568 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py @@ -32,29 +32,32 @@ class DeviceCodeCredential(InteractiveCredential): convenient because it automatically opens a browser to the login page. :param str client_id: client ID of the application users will authenticate to. When not specified users will - authenticate to an Azure development application. + authenticate to an Azure development application. - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. - :keyword str tenant_id: an Azure Active Directory tenant ID. Defaults to the 'organizations' tenant, which can - authenticate work or school accounts. **Required for single-tenant applications.** + :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. + :keyword str tenant_id: an Azure Active Directory tenant ID. Defaults to the "organizations" tenant, which can + authenticate work or school accounts. **Required for single-tenant applications.** :keyword int timeout: seconds to wait for the user to authenticate. Defaults to the validity period of the - device code as set by Azure Active Directory, which also prevails when ``timeout`` is longer. + device code as set by Azure Active Directory, which also prevails when **timeout** is longer. :keyword prompt_callback: A callback enabling control of how authentication - instructions are presented. Must accept arguments (``verification_uri``, ``user_code``, ``expires_on``): + instructions are presented. Must accept arguments (``verification_uri``, ``user_code``, ``expires_on``): - - ``verification_uri`` (str) the URL the user must visit - - ``user_code`` (str) the code the user must enter there - - ``expires_on`` (datetime.datetime) the UTC time at which the code will expire - If this argument isn't provided, the credential will print instructions to stdout. + - ``verification_uri`` (str) the URL the user must visit + - ``user_code`` (str) the code the user must enter there + - ``expires_on`` (datetime.datetime) the UTC time at which the code will expire + If this argument isn't provided, the credential will print instructions to stdout. :paramtype prompt_callback: Callable[str, str, ~datetime.datetime] :keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate` :keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise - :class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False. + :class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False. :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential - will cache tokens in memory. + will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant specified by **tenant_id**. """ def __init__(self, client_id=DEVELOPER_SIGN_ON_CLIENT_ID, **kwargs): @@ -70,7 +73,7 @@ def _request_token(self, *scopes, **kwargs): # MSAL requires scopes be a list scopes = list(scopes) # type: ignore - app = self._get_app() + app = self._get_app(**kwargs) flow = app.initiate_device_flow(scopes) if "error" in flow: raise ClientAuthenticationError( diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py index 538e454392a1..97aecf8c03f1 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py @@ -52,6 +52,10 @@ class EnvironmentCredential(object): - **AZURE_TENANT_ID**: (optional) ID of the service principal's tenant. Also called its 'directory' ID. If not provided, defaults to the 'organizations' tenant, which supports only Azure Active Directory work or school accounts. + + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application or user is registered in. When False, which is the default, the credential will acquire tokens + only from the tenant specified by **AZURE_TENANT_ID**. """ def __init__(self, **kwargs): @@ -105,7 +109,11 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.identity.CredentialUnavailableError: environment variable configuration is incomplete """ if not self._credential: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index 8983d120fda3..4380c38d899f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -51,8 +51,8 @@ def __init__(self, **kwargs): self._error_message = None # type: Optional[str] self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) def _request_token(self, *scopes, **kwargs): # pylint:disable=unused-argument diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py index 942a8d02bfaf..0594a06f84ac 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py @@ -38,8 +38,8 @@ def get_token(self, *scopes, **kwargs): ) return super(ServiceFabricCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] return self._client.get_cached_token(*scopes) def _request_token(self, *scopes, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 4ebb45a226c5..b4f7c8e2460e 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -12,7 +12,7 @@ from .. import CredentialUnavailableError from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID -from .._internal import AadClient, validate_tenant_id +from .._internal import AadClient, resolve_tenant, validate_tenant_id from .._internal.decorators import log_get_token, wrap_exceptions from .._internal.msal_client import MsalClient from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase @@ -24,7 +24,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Optional + from typing import Any, Dict, Optional from .. import AuthenticationRecord from .._internal import AadClientBase @@ -46,6 +46,10 @@ class SharedTokenCacheCredential(SharedTokenCacheBase): :keyword cache_persistence_options: configuration for persistent token caching. If not provided, the credential will use the persistent cache shared by Microsoft development applications :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or, if a value was given for **authentication_record**, the tenant specified by the + :class:`AuthenticationRecord`. """ def __init__(self, username=None, **kwargs): @@ -56,9 +60,10 @@ def __init__(self, username=None, **kwargs): # authenticate in the tenant that produced the record unless "tenant_id" specifies another self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id validate_tenant_id(self._tenant_id) + self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False) self._cache = kwargs.pop("_cache", None) - self._app = None - self._client_kwargs = kwargs + self._client_applications = {} # type: Dict[str, PublicClientApplication] + self._msal_client = MsalClient(**kwargs) self._initialized = False else: super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs) @@ -101,7 +106,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account): - token = self._client.obtain_token_by_refresh_token(scopes, refresh_token) + token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) return token raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) @@ -119,34 +124,35 @@ def _initialize(self): return self._load_cache() - if self._cache: - if "AZURE_IDENTITY_DISABLE_CP1" in os.environ: - capabilities = None - else: - capabilities = ["CP1"] # able to handle CAE claims challenges - self._app = PublicClientApplication( + self._initialized = True + + def _get_client_application(self, **kwargs): + tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs) + if tenant_id not in self._client_applications: + # CP1 = can handle claims challenges (CAE) + capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"] + self._client_applications[tenant_id] = PublicClientApplication( client_id=self._auth_record.client_id, - authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id), + authority="https://{}/{}".format(self._auth_record.authority, tenant_id), token_cache=self._cache, - http_client=MsalClient(**self._client_kwargs), + http_client=self._msal_client, client_capabilities=capabilities ) - - self._initialized = True + return self._client_applications[tenant_id] @wrap_exceptions def _acquire_token_silent(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Silently acquire a token from MSAL. Requires an AuthenticationRecord.""" - # self._auth_record and ._app will not be None when this method is called by get_token - # but should either be None anyway (and to satisfy mypy) we raise - if self._app is None or self._auth_record is None: + # this won't be None when this method is called by get_token but we check anyway to satisfy mypy + if self._auth_record is None: raise CredentialUnavailableError("Initialization failed") result = None - accounts_for_user = self._app.get_accounts(username=self._auth_record.username) + client_application = self._get_client_application(**kwargs) + accounts_for_user = client_application.get_accounts(username=self._auth_record.username) if not accounts_for_user: raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.") @@ -155,7 +161,7 @@ def _acquire_token_silent(self, *scopes, **kwargs): continue now = int(time.time()) - result = self._app.acquire_token_silent_with_error( + result = client_application.acquire_token_silent_with_error( list(scopes), account=account, claims_challenge=kwargs.get("claims") ) if result and "access_token" in result and "expires_in" in result: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py b/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py index 99e5c9b1955a..77281a185e6e 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py @@ -21,21 +21,25 @@ class UsernamePasswordCredential(InteractiveCredential): a directory admin. This credential can only authenticate work and school accounts; Microsoft accounts are not supported. - See this document for more information about account types: - https://docs.microsoft.com/azure/active-directory/fundamentals/sign-up-organization + See `Azure Active Directory documentation + `_ for more information about + account types. :param str client_id: the application's client ID :param str username: the user's username (usually an email address) :param str password: the user's password - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. + :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. :keyword str tenant_id: tenant ID or a domain associated with a tenant. If not provided, defaults to the - 'organizations' tenant, which supports only Azure Active Directory work or school accounts. + "organizations" tenant, which supports only Azure Active Directory work or school accounts. :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential - will cache tokens in memory. + will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant specified by **tenant_id**. """ def __init__(self, client_id, username, password, **kwargs): @@ -53,7 +57,7 @@ def __init__(self, client_id, username, password, **kwargs): @wrap_exceptions def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> dict - app = self._get_app() + app = self._get_app(**kwargs) return app.acquire_token_by_username_password( username=self._username, password=self._password, diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index db54a33bf18f..904766782bf7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -120,6 +120,9 @@ class VisualStudioCodeCredential(_VSCodeCredentialBase, GetTokenMixin): :keyword str tenant_id: ID of the tenant the credential should authenticate in. Defaults to the "Azure: Tenant" setting in VS Code's user settings or, when that setting has no value, the "organizations" tenant, which supports only Azure Active Directory work or school accounts. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant configured by **tenant_id** or VS Code's user settings. """ def get_token(self, *scopes, **kwargs): @@ -137,10 +140,10 @@ def get_token(self, *scopes, **kwargs): raise CredentialUnavailableError(message=self._unavailable_reason) return super(VisualStudioCodeCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] self._client = cast(AadClient, self._client) - return self._client.get_cached_access_token(scopes) + return self._client.get_cached_access_token(scopes, **kwargs) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index 39f554bc47a7..93e032cbb829 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -3,11 +3,17 @@ # Licensed under the MIT License. # ------------------------------------ import os +from typing import TYPE_CHECKING from six.moves.urllib_parse import urlparse +from azure.core.exceptions import ClientAuthenticationError + from .._constants import EnvironmentVariables, KnownAuthorities +if TYPE_CHECKING: + from typing import Any, Optional + def normalize_authority(authority): # type: (str) -> str @@ -43,6 +49,24 @@ def validate_tenant_id(tenant_id): ) +def resolve_tenant(default_tenant, allow_multitenant, tenant_id=None, **_): + # type: (str, bool, Optional[str], **Any) -> str + """Returns the correct tenant for a token request given a credential's configuration""" + if ( + tenant_id is None + or tenant_id == default_tenant + or os.environ.get(EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION) + ): + return default_tenant + + if not allow_multitenant: + raise ClientAuthenticationError( + 'The specified tenant for this token request, "{}", does not match'.format(tenant_id) + + ' the configured tenant, and "allow_multitenant_authentication" is False.' + ) + return tenant_id + + # pylint:disable=wrong-import-position from .aad_client import AadClient from .aad_client_base import AadClientBase @@ -74,5 +98,6 @@ def _scopes_to_resource(*scopes): "get_default_authority", "InteractiveCredential", "normalize_authority", + "resolve_tenant", "wrap_exceptions", ] diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index fb0fecc108a8..93b877afd667 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -34,7 +34,7 @@ class AadClient(AadClientBase): def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): # type: (Iterable[str], str, str, Optional[str], **Any) -> AccessToken request = self._get_auth_code_request( - scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret + scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) now = int(time.time()) response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) @@ -42,21 +42,21 @@ def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_ def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs): # type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken - request = self._get_client_certificate_request(scopes, certificate) + request = self._get_client_certificate_request(scopes, certificate, **kwargs) now = int(time.time()) response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) def obtain_token_by_client_secret(self, scopes, secret, **kwargs): # type: (Iterable[str], str, **Any) -> AccessToken - request = self._get_client_secret_request(scopes, secret) + request = self._get_client_secret_request(scopes, secret, **kwargs) now = int(time.time()) response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs): # type: (Iterable[str], str, **Any) -> AccessToken - request = self._get_refresh_token_request(scopes, refresh_token) + request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) now = int(time.time()) response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 511b05982a72..ed2d0161e443 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -16,7 +16,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from . import get_default_authority, normalize_authority -from .._constants import DEFAULT_TOKEN_REFRESH_RETRY_DELAY, DEFAULT_REFRESH_OFFSET +from .._internal import resolve_tenant try: from typing import TYPE_CHECKING @@ -44,20 +44,27 @@ class AadClientBase(ABC): _POST = ["POST"] - def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs): - # type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None - authority = normalize_authority(authority) if authority else get_default_authority() - self._token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token")) + def __init__( + self, tenant_id, client_id, authority=None, cache=None, allow_multitenant_authentication=False, **kwargs + ): + # type: (str, str, Optional[str], Optional[TokenCache], bool, **Any) -> None + self._authority = normalize_authority(authority) if authority else get_default_authority() + + self._tenant_id = tenant_id + self._allow_multitenant = allow_multitenant_authentication + self._cache = cache or TokenCache() self._client_id = client_id self._pipeline = self._build_pipeline(**kwargs) - self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY - self._token_refresh_offset = DEFAULT_REFRESH_OFFSET - self._last_refresh_time = 0 - def get_cached_access_token(self, scopes, query=None): - # type: (Iterable[str], Optional[dict]) -> Optional[AccessToken] - tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query) + def get_cached_access_token(self, scopes, **kwargs): + # type: (Iterable[str], **Any) -> Optional[AccessToken] + tenant = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs) + tokens = self._cache.find( + TokenCache.CredentialType.ACCESS_TOKEN, + target=list(scopes), + query={"client_id": self._client_id, "realm": tenant}, + ) for token in tokens: expires_on = int(token["expires_on"]) if expires_on > int(time.time()): @@ -91,8 +98,6 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs): def _process_response(self, response, request_time): # type: (PipelineResponse, int) -> AccessToken - self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time - content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) if response.http_request.body.get("grant_type") == "refresh_token": @@ -133,17 +138,18 @@ def _process_response(self, response, request_time): # caching is the final step because 'add' mutates 'content' self._cache.add( event={ + "client_id": self._client_id, "response": content, "scope": response.http_request.body["scope"].split(), - "client_id": self._client_id, + "token_endpoint": response.http_request.url, }, now=request_time, ) return token - def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None): - # type: (Iterable[str], str, str, Optional[str]) -> HttpRequest + def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None, **kwargs): + # type: (Iterable[str], str, str, Optional[str], **Any) -> HttpRequest data = { "client_id": self._client_id, "code": code, @@ -154,14 +160,13 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None) if client_secret: data["client_secret"] = client_secret - request = HttpRequest( - "POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data - ) + request = self._post(data, **kwargs) return request - def _get_client_certificate_request(self, scopes, certificate): - # type: (Iterable[str], AadClientCertificate) -> HttpRequest - assertion = self._get_jwt_assertion(certificate) + def _get_client_certificate_request(self, scopes, certificate, **kwargs): + # type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest + audience = self._get_token_url(**kwargs) + assertion = self._get_jwt_assertion(certificate, audience) data = { "client_assertion": assertion, "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", @@ -170,26 +175,22 @@ def _get_client_certificate_request(self, scopes, certificate): "scope": " ".join(scopes), } - request = HttpRequest( - "POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data - ) + request = self._post(data, **kwargs) return request - def _get_client_secret_request(self, scopes, secret): - # type: (Iterable[str], str) -> HttpRequest + def _get_client_secret_request(self, scopes, secret, **kwargs): + # type: (Iterable[str], str, **Any) -> HttpRequest data = { "client_id": self._client_id, "client_secret": secret, "grant_type": "client_credentials", "scope": " ".join(scopes), } - request = HttpRequest( - "POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data - ) + request = self._post(data, **kwargs) return request - def _get_jwt_assertion(self, certificate): - # type: (AadClientCertificate) -> str + def _get_jwt_assertion(self, certificate, audience): + # type: (AadClientCertificate, str) -> str now = int(time.time()) header = six.ensure_binary( json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8" @@ -198,7 +199,7 @@ def _get_jwt_assertion(self, certificate): json.dumps( { "jti": str(uuid4()), - "aud": self._token_endpoint, + "aud": audience, "iss": self._client_id, "sub": self._client_id, "nbf": now, @@ -213,8 +214,8 @@ def _get_jwt_assertion(self, certificate): return jwt_bytes.decode("utf-8") - def _get_refresh_token_request(self, scopes, refresh_token): - # type: (Iterable[str], str) -> HttpRequest + def _get_refresh_token_request(self, scopes, refresh_token, **kwargs): + # type: (Iterable[str], str, **Any) -> HttpRequest data = { "grant_type": "refresh_token", "refresh_token": refresh_token, @@ -222,11 +223,19 @@ def _get_refresh_token_request(self, scopes, refresh_token): "client_id": self._client_id, "client_info": 1, # request AAD include home_account_id in its response } - request = HttpRequest( - "POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data - ) + request = self._post(data, **kwargs) return request + def _get_token_url(self, **kwargs): + # type: (**Any) -> str + tenant = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs) + return "/".join((self._authority, tenant, "oauth2/v2.0/token")) + + def _post(self, data, **kwargs): + # type: (dict, **Any) -> HttpRequest + url = self._get_token_url(**kwargs) + return HttpRequest("POST", url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}) + def _scrub_secrets(response): # type: (dict) -> None diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py index 9b6d9186f49b..2798776ac196 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py @@ -5,8 +5,6 @@ import time from typing import TYPE_CHECKING -import msal - from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from .get_token_mixin import GetTokenMixin @@ -24,7 +22,7 @@ class ClientCredentialBase(MsalCredential, GetTokenMixin): @wrap_exceptions def _acquire_token_silently(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] - app = self._get_app() + app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=None, **kwargs) if result and "access_token" in result and "expires_in" in result: @@ -34,7 +32,7 @@ def _acquire_token_silently(self, *scopes, **kwargs): @wrap_exceptions def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] - app = self._get_app() + app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_for_client(list(scopes)) if "access_token" not in result: @@ -42,9 +40,3 @@ def _request_token(self, *scopes, **kwargs): raise ClientAuthenticationError(message=message) return AccessToken(result["access_token"], request_time + int(result["expires_in"])) - - def _get_app(self): - # type: () -> msal.ConfidentialClientApplication - if not self._msal_app: - self._msal_app = self._create_app(msal.ConfidentialClientApplication) - return self._msal_app diff --git a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py index c927504d2bd3..d605e92e765e 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py @@ -31,8 +31,8 @@ def __init__(self, *args, **kwargs): super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] """Attempt to acquire an access token from a cache or by redeeming a refresh token""" @abc.abstractmethod @@ -56,20 +56,24 @@ def get_token(self, *scopes, **kwargs): This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks - required data, state, or platform support + required data, state, or platform support :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. + attribute gives a reason. """ if not scopes: raise ValueError('"get_token" requires at least one scope') try: - token = self._acquire_token_silently(*scopes) + token = self._acquire_token_silently(*scopes, **kwargs) if not token: self._last_request_time = int(time.time()) - token = self._request_token(*scopes) + token = self._request_token(*scopes, **kwargs) elif self._should_refresh(token): try: self._last_request_time = int(time.time()) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index 2eb9653900e9..a095c783c595 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -8,11 +8,9 @@ import base64 import json import logging -import os import time from typing import TYPE_CHECKING -import msal import six from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -23,6 +21,11 @@ from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError from .._internal import wrap_exceptions +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore + if TYPE_CHECKING: # pylint:disable=ungrouped-imports,unused-import from typing import Any, Optional @@ -81,7 +84,7 @@ def _build_auth_record(response): six.raise_from(auth_error, ex) -class InteractiveCredential(MsalCredential): +class InteractiveCredential(MsalCredential, ABC): def __init__(self, **kwargs): self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False) self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] @@ -106,13 +109,17 @@ def get_token(self, *scopes, **kwargs): :param str scopes: desired scopes for the access token. This method requires at least one scope. :keyword str claims: additional claims required in the token, such as those returned in a resource provider's claims challenge following an authorization failure + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks - required data, state, or platform support + required data, state, or platform support :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. + attribute gives a reason. :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is - configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. + configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. """ if not scopes: message = "'get_token' requires at least one scope" @@ -148,7 +155,10 @@ def get_token(self, *scopes, **kwargs): self._auth_record = _build_auth_record(result) except Exception as ex: # pylint:disable=broad-except _LOGGER.warning( - "%s.get_token failed: %s", self.__class__.__name__, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + "%s.get_token failed: %s", + self.__class__.__name__, + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) raise @@ -188,7 +198,7 @@ def _acquire_token_silent(self, *scopes, **kwargs): result = None claims = kwargs.get("claims") if self._auth_record: - app = self._get_app() + app = self._get_app(**kwargs) for account in app.get_accounts(username=self._auth_record.username): if account.get("home_account_id") != self._auth_record.home_account_id: continue @@ -204,16 +214,6 @@ def _acquire_token_silent(self, *scopes, **kwargs): raise AuthenticationRequiredError(scopes, claims=claims, response=response) raise AuthenticationRequiredError(scopes, claims=claims) - def _get_app(self): - # type: () -> msal.PublicClientApplication - if not self._msal_app: - if "AZURE_IDENTITY_DISABLE_CP1" in os.environ: - capabilities = None - else: - capabilities = ["CP1"] # able to handle CAE claims challenges - self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=capabilities) - return self._msal_app - @abc.abstractmethod def _request_token(self, *scopes, **kwargs): pass diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index a03209e935fb..3732f1595313 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -2,21 +2,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import abc import os import msal from .msal_client import MsalClient from .._constants import EnvironmentVariables -from .._internal import get_default_authority, normalize_authority, validate_tenant_id +from .._internal import get_default_authority, normalize_authority, resolve_tenant, validate_tenant_id from .._persistent_cache import _load_persistent_cache -try: - ABC = abc.ABC -except AttributeError: # Python 2.7, abc exists, but not ABC - ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore - try: from typing import TYPE_CHECKING except ImportError: @@ -24,14 +18,14 @@ if TYPE_CHECKING: # pylint:disable=ungrouped-imports,unused-import - from typing import Any, Optional, Type, Union + from typing import Any, Dict, Optional, Union -class MsalCredential(ABC): +class MsalCredential(object): """Base class for credentials wrapping MSAL applications""" def __init__(self, client_id, client_credential=None, **kwargs): - # type: (str, Optional[Union[str, dict]], **Any) -> None + # type: (str, Optional[Union[str, Dict]], **Any) -> None authority = kwargs.pop("authority", None) self._authority = normalize_authority(authority) if authority else get_default_authority() self._regional_authority = kwargs.pop( @@ -39,7 +33,10 @@ def __init__(self, client_id, client_credential=None, **kwargs): ) self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" validate_tenant_id(self._tenant_id) + self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False) + self._client = MsalClient(**kwargs) + self._client_applications = {} # type: Dict[str, msal.ClientApplication] self._client_credential = client_credential self._client_id = client_id @@ -51,27 +48,23 @@ def __init__(self, client_id, client_credential=None, **kwargs): else: self._cache = msal.TokenCache() - self._client = MsalClient(**kwargs) - - # postpone creating the wrapped application because its initializer uses the network - self._msal_app = None # type: Optional[msal.ClientApplication] super(MsalCredential, self).__init__() - @abc.abstractmethod - def _get_app(self): - # type: () -> msal.ClientApplication - pass - - def _create_app(self, cls, **kwargs): - # type: (Type[msal.ClientApplication], **Any) -> msal.ClientApplication - app = cls( - client_id=self._client_id, - client_credential=self._client_credential, - authority="{}/{}".format(self._authority, self._tenant_id), - azure_region=self._regional_authority, - token_cache=self._cache, - http_client=self._client, - **kwargs - ) + def _get_app(self, **kwargs): + # type: (**Any) -> msal.ClientApplication + tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs) + if tenant_id not in self._client_applications: + # CP1 = can handle claims challenges (CAE) + capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"] + cls = msal.ConfidentialClientApplication if self._client_credential else msal.PublicClientApplication + self._client_applications[tenant_id] = cls( + client_id=self._client_id, + client_credential=self._client_credential, + client_capabilities=capabilities, + authority="{}/{}".format(self._authority, tenant_id), + azure_region=self._regional_authority, + token_cache=self._cache, + http_client=self._client, + ) - return app + return self._client_applications[tenant_id] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py index de7703c7e570..e09f35006c1e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py @@ -42,7 +42,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": return self._client.get_cached_token(*scopes) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py index c8a8f01d8787..484abbb9f871 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py @@ -48,6 +48,10 @@ class AzureApplicationCredential(ChainedTokenCredential): `_ for an overview of managed identities. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application or user is registered in. When False, which is the default, the credential will acquire tokens + only from the tenant specified by **AZURE_TENANT_ID**. This argument doesn't apply to managed identity + authentication. :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud, which is the default when no value is given for this keyword argument or environment variable AZURE_AUTHORITY_HOST. :class:`~azure.identity.AzureAuthorityHosts` defines authorities for diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index d80f59a9e973..225fbe434d94 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -10,25 +10,29 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Iterable, Optional + from typing import Any, Optional from azure.core.credentials import AccessToken class AuthorizationCodeCredential(AsyncContextManager, GetTokenMixin): """Authenticates by redeeming an authorization code previously obtained from Azure Active Directory. - See https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information + See `Azure Active Directory documentation + `_ for more information about the authentication flow. - :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its 'directory' ID. + :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its "directory" ID. :param str client_id: the application's client ID :param str authorization_code: the authorization code from the user's log-in :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. + :keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com", + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. :keyword str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant specified by **tenant_id**. """ async def __aenter__(self): @@ -62,15 +66,19 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": redeeming the authorization code. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. Any error response from Azure Active Directory is available as the error's ``response`` attribute. """ return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_access_token(scopes) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_access_token(scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": if self._authorization_code: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index 2445a678bb74..cb5c14bb26d4 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -64,7 +64,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": return self._client.get_cached_token(*scopes) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py index 2db249d9f750..d01a3619462d 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py @@ -20,7 +20,7 @@ parse_token, sanitize_output, ) -from ..._internal import _scopes_to_resource +from ..._internal import _scopes_to_resource, resolve_tenant if TYPE_CHECKING: from typing import Any @@ -31,8 +31,15 @@ class AzureCliCredential(AsyncContextManager): """Authenticates by requesting a token from the Azure CLI. This requires previously logging in to Azure via "az login", and will use the CLI's currently logged in identity. + + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the identity logged in to the Azure CLI is registered in. When False, which is the default, the credential will + acquire tokens only from the tenant of the Azure CLI's active subscription. """ + def __init__(self, **kwargs: "Any") -> None: + self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + @log_get_token_async async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": """Request an access token for `scopes`. @@ -41,6 +48,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": also handle token caching because this credential doesn't cache the tokens it acquires. :param str scopes: desired scope for the access token. This credential allows only one scope per request. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke the Azure CLI. @@ -52,7 +62,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return _SyncAzureCliCredential().get_token(*scopes, **kwargs) resource = _scopes_to_resource(*scopes) - output = await _run_command(COMMAND_LINE.format(resource)) + command = COMMAND_LINE.format(resource) + tenant = resolve_tenant("", self._allow_multitenant, **kwargs) + if tenant: + command += " --tenant " + tenant + output = await _run_command(command) token = parse_token(output) if not token: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py index 8ce2976e29a5..c99433bb13d4 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py @@ -16,6 +16,7 @@ raise_for_error, parse_token, ) +from ..._internal import resolve_tenant if TYPE_CHECKING: # pylint:disable=ungrouped-imports @@ -27,8 +28,15 @@ class AzurePowerShellCredential(AsyncContextManager): """Authenticates by requesting a token from Azure PowerShell. This requires previously logging in to Azure via "Connect-AzAccount", and will use the currently logged in identity. + + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the identity logged in to Azure PowerShell is registered in. When False, which is the default, the credential + will acquire tokens only from the tenant of Azure PowerShell's active subscription. """ + def __init__(self, **kwargs: "Any") -> None: + self._allow_multitenant = kwargs.get("allow_multitenant_authentication", False) + @log_get_token_async async def get_token( self, *scopes: str, **kwargs: "Any" @@ -39,6 +47,9 @@ async def get_token( also handle token caching because this credential doesn't cache the tokens it acquires. :param str scopes: desired scope for the access token. This credential allows only one scope per request. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke Azure PowerShell, or @@ -50,7 +61,8 @@ async def get_token( if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncCredential().get_token(*scopes, **kwargs) - command_line = get_command_line(scopes) + tenant_id = resolve_tenant("", self._allow_multitenant, **kwargs) + command_line = get_command_line(scopes, tenant_id) output = await run_command_line(command_line) token = parse_token(output) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index db78d7e263ae..a78b9b790eac 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -40,6 +40,9 @@ class CertificateCredential(AsyncContextManager, GetTokenMixin): :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application is registered in. When False, which is the default, the credential will acquire tokens only from + the tenant specified by **tenant_id**. """ def __init__(self, tenant_id, client_id, certificate_path=None, **kwargs): @@ -71,8 +74,8 @@ async def close(self): await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_access_token(scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index d77ecc1030a8..676e0b15e790 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -29,6 +29,9 @@ class ClientSecretCredential(AsyncContextManager, GetTokenMixin): :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential will cache tokens in memory. :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application is registered in. When False, which is the default, the credential will acquire tokens only from + the tenant specified by **tenant_id**. """ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs: "Any") -> None: @@ -62,8 +65,8 @@ async def close(self): await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_access_token(scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py index 9685d5acce32..512fc621a21b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py @@ -47,7 +47,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": ) return await super(CloudShellCredential, self).get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": return self._client.get_cached_token(*scopes) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index cf5085556b9e..8888e5d28874 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -42,9 +42,12 @@ class DefaultAzureCredential(ChainedTokenCredential): This default behavior is configurable with keyword arguments. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application is registered in. When False, which is the default, the credential will acquire tokens only from + its configured tenant. This argument doesn't apply to managed identity authentication. :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud. + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` + defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud. :keyword bool exclude_cli_credential: Whether to exclude the Azure CLI from the credential. Defaults to **False**. :keyword bool exclude_environment_credential: Whether to exclude a service principal configured by environment variables from the credential. Defaults to **False**. @@ -73,7 +76,7 @@ def __init__(self, **kwargs: "Any") -> None: vscode_tenant_id = kwargs.pop( "visual_studio_code_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID) ) - vscode_args = {} + vscode_args = dict(kwargs) if authority: vscode_args["authority"] = authority if vscode_tenant_id: @@ -118,9 +121,9 @@ def __init__(self, **kwargs: "Any") -> None: if not exclude_visual_studio_code_credential: credentials.append(VisualStudioCodeCredential(**vscode_args)) if not exclude_cli_credential: - credentials.append(AzureCliCredential()) + credentials.append(AzureCliCredential(**kwargs)) if not exclude_powershell_credential: - credentials.append(AzurePowerShellCredential()) + credentials.append(AzurePowerShellCredential(**kwargs)) super().__init__(*credentials) @@ -130,6 +133,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a `message` attribute listing each authentication attempt and its error message. """ diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index ac6951d9b9d8..a1e84100b005 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -37,6 +37,10 @@ class EnvironmentCredential(AsyncContextManager): - **AZURE_CLIENT_ID**: the service principal's client ID - **AZURE_CLIENT_CERTIFICATE_PATH**: path to a PEM-encoded certificate file including the private key. The certificate must not be password-protected. + + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the application or user is registered in. When False, which is the default, the credential will acquire tokens + only from the tenant specified by **AZURE_TENANT_ID**. """ def __init__(self, **kwargs: "Any") -> None: @@ -87,7 +91,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.identity.CredentialUnavailableError: environment variable configuration is incomplete """ if not self._credential: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index a562ac1d46a6..9274adf6aea5 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -38,7 +38,7 @@ async def __aenter__(self): async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": return self._client.get_cached_token(*scopes) async def _request_token(self, *scopes, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py index 53e476d6b171..5e8de07d763e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py @@ -42,7 +42,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": return self._client.get_cached_token(*scopes) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 30f50b937d44..b663f16623af 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -32,6 +32,9 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager): :keyword cache_persistence_options: configuration for persistent token caching. If not provided, the credential will use the persistent cache shared by Microsoft development applications :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant. """ async def __aenter__(self): @@ -54,7 +57,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user information :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` @@ -78,7 +85,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account): - token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token) + token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) return token raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index e180090d40d8..586354f8ff30 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -26,6 +26,9 @@ class VisualStudioCodeCredential(_VSCodeCredentialBase, AsyncContextManager, Get :keyword str tenant_id: ID of the tenant the credential should authenticate in. Defaults to the "Azure: Tenant" setting in VS Code's user settings or, when that setting has no value, the "organizations" tenant, which supports only Azure Active Directory work or school accounts. + :keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant + the user is registered in. When False, which is the default, the credential will acquire tokens only from the + user's home tenant or the tenant configured by **tenant_id** or VS Code's user settings. """ async def __aenter__(self) -> "VisualStudioCodeCredential": @@ -45,9 +48,13 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual - Studio Code + Studio Code """ if self._unavailable_reason: raise CredentialUnavailableError(message=self._unavailable_reason) @@ -56,9 +63,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": self._client = cast(AadClient, self._client) - return self._client.get_cached_access_token(scopes) + return self._client.get_cached_access_token(scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": refresh_token = self._get_refresh_token() diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 9237eb4ef855..76c8c8af5bab 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -52,7 +52,7 @@ async def obtain_token_by_authorization_code( **kwargs: "Any" ) -> "AccessToken": request = self._get_auth_code_request( - scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret + scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) now = int(time.time()) response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs) @@ -60,7 +60,7 @@ async def obtain_token_by_authorization_code( async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs): # type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken - request = self._get_client_certificate_request(scopes, certificate) + request = self._get_client_certificate_request(scopes, certificate, **kwargs) now = int(time.time()) response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) @@ -68,7 +68,7 @@ async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs async def obtain_token_by_client_secret( self, scopes: "Iterable[str]", secret: str, **kwargs: "Any" ) -> "AccessToken": - request = self._get_client_secret_request(scopes, secret) + request = self._get_client_secret_request(scopes, secret, **kwargs) now = int(time.time()) response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) @@ -76,7 +76,7 @@ async def obtain_token_by_client_secret( async def obtain_token_by_refresh_token( self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any" ) -> "AccessToken": - request = self._get_refresh_token_request(scopes, refresh_token) + request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) now = int(time.time()) response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index 204cffff1d6d..5da662f4e2b4 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -25,7 +25,7 @@ def __init__(self, *args: "Any", **kwargs: "Any") -> None: super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": """Attempt to acquire an access token from a cache or by redeeming a refresh token""" @abc.abstractmethod @@ -46,20 +46,24 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. + :keyword str tenant_id: optional tenant to include in the token request. If **allow_multitenant_authentication** + is False, specifying a tenant with this argument may raise an exception. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks - required data, state, or platform support + required data, state, or platform support :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. + attribute gives a reason. """ if not scopes: raise ValueError('"get_token" requires at least one scope') try: - token = await self._acquire_token_silently(*scopes) + token = await self._acquire_token_silently(*scopes, **kwargs) if not token: self._last_request_time = int(time.time()) - token = await self._request_token(*scopes) + token = await self._request_token(*scopes, **kwargs) elif self._should_refresh(token): try: self._last_request_time = int(time.time()) diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 2805c54f3135..b1be22602c72 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -153,8 +153,14 @@ def mock_response(status_code=200, headers=None, json_payload=None): def get_discovery_response(endpoint="https://a/b"): + """Get a mock response containing the values MSAL requires from tenant and instance discovery. + + The response is incomplete and its values aren't necessarily valid, particularly for instance discovery, but it's + sufficient. MSAL will send token requests to "{endpoint}/oauth2/v2.0/token_endpoint" after receiving a tenant + discovery response created by this method. + """ aad_metadata_endpoint_names = ("authorization_endpoint", "token_endpoint", "tenant_discovery_endpoint") - payload = {name: endpoint for name in aad_metadata_endpoint_names} + payload = {name: endpoint + "/oauth2/v2.0/" + name for name in aad_metadata_endpoint_names} payload["metadata"] = "" return mock_response(json_payload=payload) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 4a313608a3b1..9f798d50c4fe 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -56,7 +56,11 @@ def test_exceptions_do_not_expose_secrets(): fns = [ functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), - functools.partial(client.obtain_token_by_refresh_token, "refresh token", ("scope"),), + functools.partial( + client.obtain_token_by_refresh_token, + "refresh token", + ("scope"), + ), ] def assert_secrets_not_exposed(): @@ -233,3 +237,76 @@ def test_retries_token_requests(): client.obtain_token_by_refresh_token("", "") assert transport.send.call_count > 1 transport.send.reset_mock() + + +def test_shared_cache(): + """The client should return only tokens associated with its own client_id""" + + client_id_a = "client-id-a" + client_id_b = "client-id-b" + scope = "scope" + expected_token = "***" + tenant_id = "tenant" + authority = "https://localhost/" + tenant_id + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id_a, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, tenant_id=tenant_id) + client_a = AadClient(client_id=client_id_a, **common_args) + client_b = AadClient(client_id=client_id_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + +def test_multitenant_cache(): + client_id = "client-id" + scope = "scope" + expected_token = "***" + tenant_a = "tenant-a" + tenant_b = "tenant-b" + tenant_c = "tenant-c" + authority = "https://localhost/" + tenant_a + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_a, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, client_id=client_id) + client_a = AadClient(tenant_id=tenant_a, **common_args) + client_b = AadClient(tenant_id=tenant_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + # but C allows multitenant auth and should therefore return the token from tenant_a when appropriate + client_c = AadClient(tenant_id=tenant_c, allow_multitenant_authentication=True, **common_args) + assert client_c.get_cached_access_token([scope]) is None + token = client_c.get_cached_access_token([scope], tenant_id=tenant_a) + assert token.token == expected_token + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + assert client_c.get_cached_access_token([scope], tenant_id=tenant_a) is None diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 81d3fc2d827f..64212d573e97 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -5,13 +5,11 @@ import functools from unittest.mock import Mock, patch from urllib.parse import urlparse -import time from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError -from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY +from azure.identity._constants import EnvironmentVariables from azure.identity._internal import AadClientCertificate from azure.identity.aio._internal.aad_client import AadClient -from azure.core.credentials import AccessToken from msal import TokenCache import pytest @@ -241,3 +239,76 @@ async def test_retries_token_requests(): await client.obtain_token_by_refresh_token("", "") assert transport.send.call_count > 1 transport.send.reset_mock() + + +async def test_shared_cache(): + """The client should return only tokens associated with its own client_id""" + + client_id_a = "client-id-a" + client_id_b = "client-id-b" + scope = "scope" + expected_token = "***" + tenant_id = "tenant" + authority = "https://localhost/" + tenant_id + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id_a, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, tenant_id=tenant_id) + client_a = AadClient(client_id=client_id_a, **common_args) + client_b = AadClient(client_id=client_id_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + +async def test_multitenant_cache(): + client_id = "client-id" + scope = "scope" + expected_token = "***" + tenant_a = "tenant-a" + tenant_b = "tenant-b" + tenant_c = "tenant-c" + authority = "https://localhost/" + tenant_a + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_a, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, client_id=client_id) + client_a = AadClient(tenant_id=tenant_a, **common_args) + client_b = AadClient(tenant_id=tenant_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + # but C allows multitenant auth and should therefore return the token from tenant_a when appropriate + client_c = AadClient(tenant_id=tenant_c, allow_multitenant_authentication=True, **common_args) + assert client_c.get_cached_access_token([scope]) is None + token = client_c.get_cached_access_token([scope], tenant_id=tenant_a) + assert token.token == expected_token + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + assert client_c.get_cached_access_token([scope], tenant_id=tenant_a) is None diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py index 7b6ce76a75ed..29ab3733a633 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code.py +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -2,19 +2,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import AuthorizationCodeCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import msal import pytest +from six.moves.urllib_parse import urlparse from helpers import build_aad_response, mock_response, Request, validating_transport try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore def test_no_scopes(): @@ -114,3 +116,73 @@ def test_auth_code_credential(): token = credential.get_token(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + first_tenant, + "client-id", + "authcode", + "https://localhost", + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + expected_tenant, "client-id", "authcode", "https://localhost", transport=Mock(send=send) + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py index 3b754b40a6f4..5eb55e6515cc 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code_async.py +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -2,9 +2,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from unittest.mock import Mock +from unittest.mock import Mock, patch +from urllib.parse import urlparse +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import AuthorizationCodeCredential import msal @@ -137,3 +140,73 @@ async def test_auth_code_credential(): token = await credential.get_token(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 + + +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + first_tenant, + "client-id", + "authcode", + "https://localhost", + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +async def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + expected_tenant, "client-id", "authcode", "https://localhost", transport=Mock(send=send) + ) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 666cec214860..38272fdb47ec 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -5,6 +5,7 @@ import json import os +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import CertificateCredential, RegionalAuthority, TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -343,3 +344,83 @@ def test_certificate_arguments(): CertificateCredential("tenant-id", "client-id", certificate_path="...", certificate_data="...") message = str(ex.value) assert "certificate_data" in message and "certificate_path" in message + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_allow_multitenant_authentication(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant, "common"), 'unexpected tenant "{}"'.format(tenant) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + first_tenant, + "client-id", + cert_path, + password=cert_password, + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_multitenant_authentication_backcompat(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, expected_tenant)) + + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + expected_tenant, "client-id", cert_path, password=cert_password, transport=Mock(send=send) + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + os.environ, {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 001308460ead..fbfaa562f157 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -14,7 +15,7 @@ from msal import TokenCache import pytest -from helpers import build_aad_response, urlsafeb64_decode, mock_response, Request +from helpers import build_aad_response, mock_response, Request from helpers_async import async_validating_transport, AsyncMockTransport from test_certificate_credential import BOTH_CERTS, CERT_PATH, EC_CERT_PATH, validate_jwt @@ -265,3 +266,79 @@ def test_certificate_arguments(): CertificateCredential("tenant-id", "client-id", certificate_path="...", certificate_data="...") message = str(ex.value) assert "certificate_data" in message and "certificate_path" in message + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +async def test_allow_multitenant_authentication(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + first_tenant, + "client-id", + cert_path, + password=cert_password, + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +async def test_multitenant_authentication_backcompat(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + expected_tenant, "client-id", cert_path, password=cert_password, transport=Mock(send=send) + ) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_cli_credential.py b/sdk/identity/azure-identity/tests/test_cli_credential.py index eb0bb106125a..bac97fd4ac7c 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential.py @@ -4,9 +4,11 @@ # ------------------------------------ from datetime import datetime import json +import re import sys from azure.identity import AzureCliCredential, CredentialUnavailableError +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_cli import CLI_NOT_FOUND, NOT_LOGGED_IN from azure.core.exceptions import ClientAuthenticationError @@ -148,3 +150,79 @@ def test_timeout(): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=TimeoutExpired("", 42))): with pytest.raises(CredentialUnavailableError): AzureCliCredential().get_token("scope") + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + default_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def fake_check_output(command_line, **_): + match = re.search("--tenant (.*)", command_line[-1]) + tenant = match.groups()[0] if match else default_tenant + assert tenant in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + return json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": first_token if tenant == default_tenant else second_token, + "subscription": "some-guid", + "tenant": tenant, + "tokenType": "Bearer", + } + ) + + credential = AzureCliCredential(allow_multitenant_authentication=True) + with mock.patch(CHECK_OUTPUT, fake_check_output): + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=default_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def fake_check_output(command_line, **_): + match = re.search("--tenant (.*)", command_line[-1]) + assert match is None or match[1] == expected_tenant + return json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": expected_token, + "subscription": "some-guid", + "tenant": expected_token, + "tokenType": "Bearer", + } + ) + + credential = AzureCliCredential() + with mock.patch(CHECK_OUTPUT, fake_check_output): + token = credential.get_token("scope") + assert token.token == expected_token + + # specifying a tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"} + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert ( + token.token == expected_token + ), "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_cli_credential_async.py b/sdk/identity/azure-identity/tests/test_cli_credential_async.py index 69b1f8c1d41f..d5f5885f5d1f 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential_async.py @@ -5,11 +5,13 @@ import asyncio from datetime import datetime import json +import re import sys from unittest import mock from azure.identity import CredentialUnavailableError from azure.identity.aio import AzureCliCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_cli import CLI_NOT_FOUND, NOT_LOGGED_IN from azure.core.exceptions import ClientAuthenticationError import pytest @@ -181,3 +183,79 @@ async def test_timeout(): assert proc.communicate.call_count == 1 assert proc.kill.call_count == 1 + + +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + default_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def fake_exec(*args, **_): + match = re.search("--tenant (.*)", args[-1]) + tenant = match[1] if match else default_tenant + assert tenant in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + output = json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": first_token if tenant == default_tenant else second_token, + "subscription": "some-guid", + "tenant": tenant, + "tokenType": "Bearer", + } + ).encode() + return mock.Mock(communicate=mock.Mock(return_value=get_completed_future((output, b""))), returncode=0) + + credential = AzureCliCredential(allow_multitenant_authentication=True) + with mock.patch(SUBPROCESS_EXEC, fake_exec): + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=default_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +async def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def fake_exec(*args, **_): + match = re.search("--tenant (.*)", args[-1]) + assert match is None or match[1] == expected_tenant + output = json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": expected_token, + "subscription": "some-guid", + "tenant": expected_token, + "tokenType": "Bearer", + } + ).encode() + return mock.Mock(communicate=mock.Mock(return_value=get_completed_future((output, b""))), returncode=0) + + credential = AzureCliCredential() + with mock.patch(SUBPROCESS_EXEC, fake_exec): + token = await credential.get_token("scope") + assert token.token == expected_token + + # specifying a tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert ( + token.token == expected_token + ), "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index ded3c9727e1d..854990f232b6 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import ClientSecretCredential, RegionalAuthority, TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -10,7 +11,7 @@ import pytest from six.moves.urllib_parse import urlparse -from helpers import build_aad_response, mock_response, msal_validating_transport, Request +from helpers import build_aad_response, get_discovery_response, mock_response, msal_validating_transport, Request try: from unittest.mock import Mock, patch @@ -208,3 +209,72 @@ def test_cache_multiple_clients(): assert transport_b.send.call_count == 3 assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant, "common"), 'unexpected tenant "{}"'.format(tenant) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential( + first_tenant, "client-id", "secret", allow_multitenant_authentication=True, transport=Mock(send=send) + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, expected_tenant)) + + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential(expected_tenant, "client-id", "secret", transport=Mock(send=send)) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 96d4366828d6..e20ec1317342 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -247,3 +248,68 @@ async def test_cache_multiple_clients(): assert transport_b.send.call_count == 1 assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2 + + +@pytest.mark.asyncio +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential( + first_tenant, "client-id", "secret", allow_multitenant_authentication=True, transport=Mock(send=send) + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +async def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential(expected_tenant, "client-id", "secret", transport=Mock(send=send)) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 45a9519f3ee2..31c2ba1a71b4 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -314,7 +314,13 @@ def test_managed_identity_client_id(): def get_credential_for_shared_cache_test(expected_refresh_token, expected_access_token, cache, **kwargs): exclude_other_credentials = { - option: True for option in ("exclude_environment_credential", "exclude_managed_identity_credential") + option: True + for option in ( + "exclude_cli_credential", + "exclude_environment_credential", + "exclude_managed_identity_credential", + "exclude_powershell_credential", + ) } options = dict(exclude_other_credentials, **kwargs) @@ -356,3 +362,38 @@ def validate_tenant_id(credential): exclude_interactive_browser_credential=False, interactive_browser_tenant_id=tenant_id ) validate_tenant_id(mock_credential) + + +@pytest.mark.parametrize("expected_value", (True, False)) +def test_allow_multitenant_authentication(expected_value): + """the credential should pass "allow_multitenant_authentication" to the inner credentials which support it""" + + inner_credentials = { + credential: Mock() + for credential in ( + "AzureCliCredential", + "AzurePowerShellCredential", + "EnvironmentCredential", + "InteractiveBrowserCredential", + "ManagedIdentityCredential", # will ignore the argument + "SharedTokenCacheCredential", + ) + } + with patch.multiple(DefaultAzureCredential.__module__, **inner_credentials): + DefaultAzureCredential( + allow_multitenant_authentication=expected_value, exclude_interactive_browser_credential=False + ) + + for credential_name, mock_credential in inner_credentials.items(): + assert mock_credential.call_count == 1 + _, kwargs = mock_credential.call_args + + assert "allow_multitenant_authentication" in kwargs, ( + '"allow_multitenant_authentication" was not passed to ' + credential_name + ) + assert kwargs["allow_multitenant_authentication"] == expected_value + + +def test_unexpected_kwarg(): + """the credential shouldn't raise when given an unexpected keyword argument""" + DefaultAzureCredential(foo=42) diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 1fd991244c02..0f144350640c 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from azure.identity.aio._credentials import vscode import os -from unittest import mock from unittest.mock import Mock, patch from urllib.parse import urlparse @@ -312,3 +310,35 @@ def get_credential_for_shared_cache_test(expected_refresh_token, expected_access # this credential uses a mock shared cache, so it works on all platforms with patch.object(SharedTokenCacheCredential, "supported", lambda: True): return DefaultAzureCredential(_cache=cache, transport=transport, **exclude_other_credentials, **kwargs) + + +@pytest.mark.parametrize("expected_value", (True, False)) +def test_allow_multitenant_authentication(expected_value): + """the credential should pass "allow_multitenant_authentication" to the inner credentials which support it""" + + inner_credentials = { + credential: Mock() + for credential in ( + "AzureCliCredential", + "AzurePowerShellCredential", + "EnvironmentCredential", + "ManagedIdentityCredential", # will ignore the argument + "SharedTokenCacheCredential", + ) + } + with patch.multiple(DefaultAzureCredential.__module__, **inner_credentials): + DefaultAzureCredential(allow_multitenant_authentication=expected_value) + + for credential_name, mock_credential in inner_credentials.items(): + assert mock_credential.call_count == 1 + _, kwargs = mock_credential.call_args + + assert "allow_multitenant_authentication" in kwargs, ( + '"allow_multitenant_authentication" was not passed to ' + credential_name + ) + assert kwargs["allow_multitenant_authentication"] == expected_value + + +def test_unexpected_kwarg(): + """the credential shouldn't raise when given an unexpected keyword argument""" + DefaultAzureCredential(foo=42) diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index d3a0c8f52881..6947dfe20d8b 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -10,15 +10,16 @@ CredentialUnavailableError, TokenCachePersistenceOptions, ) -from azure.identity._internal import InteractiveCredential +from azure.identity._internal import EnvironmentVariables, InteractiveCredential import pytest +from six.moves.urllib_parse import urlparse try: from unittest.mock import Mock, patch except ImportError: # python < 3.3 from mock import Mock, patch # type: ignore -from helpers import build_aad_response, build_id_token, id_token_claims +from helpers import build_aad_response, get_discovery_response, id_token_claims # fake object for tests which need to exercise request_token but don't care about its return value @@ -41,24 +42,14 @@ class MockCredential(InteractiveCredential): Default instances have an empty in-memory cache, and raise rather than send an HTTP request. """ - def __init__( - self, client_id="...", request_token=None, msal_app_factory=None, transport=None, **kwargs - ): - self._msal_app_factory = msal_app_factory + def __init__(self, client_id="...", request_token=None, transport=None, **kwargs): self._request_token_impl = request_token or Mock() transport = transport or Mock(send=Mock(side_effect=Exception("credential shouldn't send a request"))) - super(MockCredential, self).__init__( - client_id=client_id, transport=transport, **kwargs - ) + super(MockCredential, self).__init__(client_id=client_id, transport=transport, **kwargs) def _request_token(self, *scopes, **kwargs): return self._request_token_impl(*scopes, **kwargs) - def _get_app(self): - if self._msal_app_factory: - return self._create_app(self._msal_app_factory) - return super(MockCredential, self)._get_app() - def test_no_scopes(): """The credential should raise when get_token is called with no scopes""" @@ -79,14 +70,13 @@ def validate_app_parameters(authority, client_id, **_): assert client_id == record.client_id return Mock(get_accounts=Mock(return_value=[])) - app_factory = Mock(wraps=validate_app_parameters) - credential = MockCredential( - authentication_record=record, disable_automatic_authentication=True, msal_app_factory=app_factory, - ) + mock_client_application = Mock(wraps=validate_app_parameters) + credential = MockCredential(authentication_record=record, disable_automatic_authentication=True) with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + with patch("msal.PublicClientApplication", mock_client_application): + credential.get_token("scope") - assert app_factory.call_count == 1, "credential didn't create an msal application" + assert mock_client_application.call_count == 1, "credential didn't create an msal application" def test_tenant_argument_overrides_record(): @@ -104,13 +94,11 @@ def validate_authority(authority, **_): return Mock(get_accounts=Mock(return_value=[])) credential = MockCredential( - authentication_record=record, - tenant_id=expected_tenant, - disable_automatic_authentication=True, - msal_app_factory=validate_authority, + authentication_record=record, tenant_id=expected_tenant, disable_automatic_authentication=True ) with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + with patch("msal.PublicClientApplication", validate_authority): + credential.get_token("scope") def test_disable_automatic_authentication(): @@ -126,14 +114,14 @@ def test_disable_automatic_authentication(): credential = MockCredential( authentication_record=record, disable_automatic_authentication=True, - msal_app_factory=lambda *_, **__: msal_app, request_token=Mock(side_effect=Exception("credential shouldn't begin interactive authentication")), ) scope = "scope" expected_claims = "..." with pytest.raises(AuthenticationRequiredError) as ex: - credential.get_token(scope, claims=expected_claims) + with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): + credential.get_token(scope, claims=expected_claims) # the exception should carry the requested scopes and claims, and any error message from AAD assert ex.value.scopes == (scope,) @@ -208,9 +196,10 @@ class CustomException(Exception): acquire_token_silent_with_error=Mock(side_effect=CustomException(expected_message)), get_accounts=Mock(return_value=[{"home_account_id": record.home_account_id}]), ) - credential = MockCredential(msal_app_factory=lambda *_, **__: msal_app, authentication_record=record) + credential = MockCredential(authentication_record=record) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): + credential.get_token("scope") assert expected_message in ex.value.message assert msal_app.acquire_token_silent_with_error.call_count == 1, "credential didn't attempt silent auth" @@ -291,3 +280,95 @@ def _request_token(self, *_, **__): assert record.home_account_id == subject assert record.tenant_id == tenant assert record.username == username + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def request_token(*args, **kwargs): + tenant_id = kwargs.get("tenant_id") + return build_aad_response( + access_token=second_token if tenant_id == second_tenant else first_token, + id_token_claims=id_token_claims( + aud="...", + iss="http://localhost/tenant", + sub="subject", + preferred_username="...", + tenant_id="...", + object_id="...", + ), + ) + + def send(request, **_): + assert "/oauth2/v2.0/token" not in request.url, 'mock "request_token" should prevent sending a token request' + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + credential = MockCredential( + tenant_id=first_tenant, + allow_multitenant_authentication=True, + request_token=request_token, + transport=Mock(send=send), + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def request_token(*_, **__): + return build_aad_response( + access_token=expected_token, + id_token_claims=id_token_claims( + aud="...", + iss="http://localhost/tenant", + sub="subject", + preferred_username="...", + tenant_id="...", + object_id="...", + ), + ) + + def send(request, **_): + assert "/oauth2/v2.0/token" not in request.url, 'mock "request_token" should prevent sending a token request' + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + credential = MockCredential(tenant_id=expected_tenant, transport=Mock(send=send), request_token=request_token) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_powershell_credential.py b/sdk/identity/azure-identity/tests/test_powershell_credential.py index ec0a4ca7f23a..79a36f9a28a4 100644 --- a/sdk/identity/azure-identity/tests/test_powershell_credential.py +++ b/sdk/identity/azure-identity/tests/test_powershell_credential.py @@ -5,8 +5,10 @@ import base64 import logging from platform import python_version +import re import subprocess import sys +import time try: from unittest.mock import Mock, patch @@ -15,6 +17,7 @@ from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzurePowerShellCredential, CredentialUnavailableError +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_powershell import ( AZ_ACCOUNT_NOT_INSTALLED, BLOCKED_BY_EXECUTION_POLICY, @@ -87,6 +90,7 @@ def test_get_token(stderr): encoded_script = command.split()[-1] decoded_script = base64.b64decode(encoded_script).decode("utf-16-le") + assert "TenantId" not in decoded_script assert "Get-AzAccessToken -ResourceUrl '{}'".format(scope) in decoded_script assert Popen().communicate.call_count == 1 @@ -243,3 +247,72 @@ def Popen(args, **kwargs): AzurePowerShellCredential().get_token("scope") assert Fake.calls == 2 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def fake_Popen(command, **_): + assert command[-1].startswith("pwsh -NonInteractive -EncodedCommand ") + encoded_script = command[-1].split()[-1] + decoded_script = base64.b64decode(encoded_script).decode("utf-16-le") + match = re.search("Get-AzAccessToken -ResourceUrl '(\S+)'(?: -TenantId (\S+))?", decoded_script) + tenant = match.groups()[1] + + assert tenant is None or tenant == second_tenant, 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant is None else second_token + stdout = "azsdk%{}%{}".format(token, int(time.time()) + 3600) + + communicate = Mock(return_value=(stdout, "")) + return Mock(communicate=communicate, returncode=0) + + credential = AzurePowerShellCredential(allow_multitenant_authentication=True) + with patch(POPEN, fake_Popen): + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_token = "***" + + def fake_Popen(command, **_): + assert command[-1].startswith("pwsh -NonInteractive -EncodedCommand ") + encoded_script = command[-1].split()[-1] + decoded_script = base64.b64decode(encoded_script).decode("utf-16-le") + match = re.search("Get-AzAccessToken -ResourceUrl '(\S+)'(?: -TenantId (\S+))?", decoded_script) + tenant = match.groups()[1] + + assert tenant is None, "credential shouldn't accept an explicit tenant ID" + stdout = "azsdk%{}%{}".format(expected_token, int(time.time()) + 3600) + + communicate = Mock(return_value=(stdout, "")) + return Mock(communicate=communicate, returncode=0) + + credential = AzurePowerShellCredential() + with patch(POPEN, fake_Popen): + token = credential.get_token("scope") + assert token.token == expected_token + + # specifying a tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="some tenant") + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = credential.get_token("scope", tenant_id="some tenant") + assert ( + token.token == expected_token + ), "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_powershell_credential_async.py b/sdk/identity/azure-identity/tests/test_powershell_credential_async.py index c04484242c7c..fec0b4c013de 100644 --- a/sdk/identity/azure-identity/tests/test_powershell_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_powershell_credential_async.py @@ -5,12 +5,15 @@ import asyncio import base64 import logging +import re import sys +import time from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError from azure.identity.aio import AzurePowerShellCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_powershell import ( AZ_ACCOUNT_NOT_INSTALLED, BLOCKED_BY_EXECUTION_POLICY, @@ -78,6 +81,7 @@ async def test_get_token(stderr): encoded_script = command.split()[-1] decoded_script = base64.b64decode(encoded_script).decode("utf-16-le") + assert "TenantId" not in decoded_script assert "Get-AzAccessToken -ResourceUrl '{}'".format(scope) in decoded_script assert mock_exec().result().communicate.call_count == 1 @@ -245,3 +249,73 @@ async def mock_exec(*args, **kwargs): await credential.get_token("scope") assert calls == 2 + + +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def fake_exec(*args, **_): + command = args[2] + assert command.startswith("pwsh -NonInteractive -EncodedCommand ") + encoded_script = command.split()[-1] + decoded_script = base64.b64decode(encoded_script).decode("utf-16-le") + match = re.search("Get-AzAccessToken -ResourceUrl '(\S+)'(?: -TenantId (\S+))?", decoded_script) + tenant = match[2] + + assert tenant is None or tenant == second_tenant, 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant is None else second_token + stdout = "azsdk%{}%{}".format(token, int(time.time()) + 3600) + + communicate = Mock(return_value=get_completed_future((stdout.encode(), b""))) + return Mock(communicate=communicate, returncode=0) + + credential = AzurePowerShellCredential(allow_multitenant_authentication=True) + with patch(CREATE_SUBPROCESS_EXEC, fake_exec): + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +async def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_token = "***" + + async def fake_exec(*args, **_): + command = args[2] + assert command.startswith("pwsh -NonInteractive -EncodedCommand ") + encoded_script = command.split()[-1] + decoded_script = base64.b64decode(encoded_script).decode("utf-16-le") + match = re.search("Get-AzAccessToken -ResourceUrl '(\S+)'(?: -TenantId (\S+))?", decoded_script) + tenant = match[2] + + assert tenant is None, "credential shouldn't accept an explicit tenant ID" + stdout = "azsdk%{}%{}".format(expected_token, int(time.time()) + 3600) + communicate = Mock(return_value=get_completed_future((stdout.encode(), b""))) + return Mock(communicate=communicate, returncode=0) + + credential = AzurePowerShellCredential() + with patch(CREATE_SUBPROCESS_EXEC, fake_exec): + token = await credential.get_token("scope") + assert token.token == expected_token + + # specifying a tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="some tenant") + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = await credential.get_token("scope", tenant_id="some tenant") + assert ( + token.token == expected_token + ), "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 12dec789f76f..41d512631706 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -6,6 +6,7 @@ from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import ( AuthenticationRecord, + AzureAuthorityHosts, CredentialUnavailableError, SharedTokenCacheCredential, ) @@ -32,6 +33,7 @@ build_aad_response, build_id_token, get_discovery_response, + id_token_claims, mock_response, msal_validating_transport, Request, @@ -569,7 +571,10 @@ def send(request, **_): cache = populated_cache( get_account_event( - "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, + "not-" + username, + "not-" + object_id, + "different-" + tenant_id, + client_id="not-" + client_id, ), ) credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache) @@ -738,25 +743,28 @@ def mock_send(request, **_): def test_client_capabilities(): """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set""" + def send(request, **_): + # expecting only the discovery requests triggered by creating an msal.PublicClientApplication + # because the cache is empty--the credential shouldn't send a token request + return get_discovery_response("https://localhost/tenant") + record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username") - transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) - credential = SharedTokenCacheCredential( - transport=transport, authentication_record=record, _cache=TokenCache() - ) + transport = Mock(send=send) + credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache()) with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication: - credential._initialize() + with pytest.raises(ClientAuthenticationError): # (cache is empty) + credential.get_token("scope") assert PublicClientApplication.call_count == 1 _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] == ["CP1"] - credential = SharedTokenCacheCredential( - transport=transport, authentication_record=record, _cache=TokenCache() - ) + credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache()) with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication: with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}): - credential._initialize() + with pytest.raises(ClientAuthenticationError): # (cache is empty) + credential.get_token("scope") assert PublicClientApplication.call_count == 1 _, kwargs = PublicClientApplication.call_args @@ -777,9 +785,7 @@ def test_claims_challenge(): ) transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) - credential = SharedTokenCacheCredential( - transport=transport, authentication_record=record, _cache=TokenCache() - ) + credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache()) with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication", lambda *_, **__: msal_app): credential.get_token("scope", claims=expected_claims) @@ -788,11 +794,211 @@ def test_claims_challenge(): assert kwargs["claims_challenge"] == expected_claims +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + default_tenant = "organizations" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant_id = parsed.path.split("/")[1] + assert tenant_id in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant_id) + return mock_response( + json_payload=build_aad_response( + access_token=second_token if tenant_id == second_tenant else first_token, + id_token_claims=id_token_claims(aud="...", iss="...", sub="..."), + ) + ) + + authority = AzureAuthorityHosts.AZURE_PUBLIC_CLOUD + expected_account = get_account_event( + "user", "object-id", "tenant-id", authority=authority, client_id="client-id", refresh_token="**" + ) + cache = populated_cache(expected_account) + + credential = SharedTokenCacheCredential( + allow_multitenant_authentication=True, authority=authority, transport=Mock(send=send), _cache=cache + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=default_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + default_tenant = "organizations" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + tenant_id = parsed.path.split("/")[1] + assert tenant_id == default_tenant + return mock_response( + json_payload=build_aad_response( + access_token=expected_token, + id_token_claims=id_token_claims(aud="...", iss="...", sub="..."), + ) + ) + + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + username = "me" + + expected_account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token="**" + ) + cache = populated_cache(expected_account) + + credential = SharedTokenCacheCredential(authority=authority, transport=Mock(send=send), _cache=cache) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=default_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="some tenant") + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = credential.get_token("scope", tenant_id="some tenant") + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" + + +def test_allow_multitenant_authentication_auth_record(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + default_tenant = "organizations" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + authority = AzureAuthorityHosts.AZURE_PUBLIC_CLOUD + object_id = "object-id" + home_account_id = object_id + "." + default_tenant + record = AuthenticationRecord(default_tenant, "client-id", authority, home_account_id, "user") + + def send(request, **_): + parsed = urlparse(request.url) + tenant_id = parsed.path.split("/")[1] + if "/oauth2/v2.0/token" not in request.url: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant_id)) + + assert tenant_id in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant_id) + return mock_response( + json_payload=build_aad_response( + access_token=second_token if tenant_id == second_tenant else first_token, + id_token_claims=id_token_claims(aud="...", iss="...", sub="..."), + ) + ) + + expected_account = get_account_event( + record.username, object_id, record.tenant_id, client_id=record.client_id, refresh_token="**" + ) + cache = populated_cache(expected_account) + + credential = SharedTokenCacheCredential( + allow_multitenant_authentication=True, + authority=authority, + transport=Mock(send=send), + authentication_record=record, + _cache=cache, + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=default_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed_authentication_record(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + default_tenant = "organizations" + expected_token = "***" + + authority = AzureAuthorityHosts.AZURE_PUBLIC_CLOUD + object_id = "object-id" + home_account_id = object_id + "." + default_tenant + record = AuthenticationRecord(default_tenant, "client-id", authority, home_account_id, "user") + + def send(request, **_): + parsed = urlparse(request.url) + tenant_id = parsed.path.split("/")[1] + if "/oauth2/v2.0/token" not in request.url: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant_id)) + + assert tenant_id == default_tenant + return mock_response( + json_payload=build_aad_response( + access_token=expected_token, + id_token_claims=id_token_claims(aud="...", iss="...", sub="..."), + ) + ) + + expected_account = get_account_event( + record.username, + object_id, + record.tenant_id, + authority=record.authority, + client_id=record.client_id, + refresh_token="**", + ) + cache = populated_cache(expected_account) + + credential = SharedTokenCacheCredential( + authority=authority, transport=Mock(send=send), authentication_record=record, _cache=cache + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=default_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="some tenant") + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="some tenant") + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" + + def get_account_event( username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None, **kwargs ): if authority: - endpoint = "https://" + "/".join((authority, utid, "path",)) + endpoint = "https://" + "/".join((authority, utid, "path")) else: endpoint = get_default_authority() + "/{}/{}".format(utid, "path") diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index ec4ca4c29153..a6d7f0d67d60 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -21,7 +21,7 @@ from msal import TokenCache import pytest -from helpers import build_aad_response, build_id_token, mock_response, Request +from helpers import build_aad_response, id_token_claims, mock_response, Request from helpers_async import async_validating_transport, AsyncMockTransport from test_shared_cache_credential import get_account_event, populated_cache @@ -603,3 +603,87 @@ async def test_initialization(): with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): await credential.get_token("scope") assert mock_cache_loader.call_count == 1 + + +@pytest.mark.asyncio +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant_id = parsed.path.split("/")[1] + return mock_response( + json_payload=build_aad_response( + access_token=second_token if tenant_id == second_tenant else first_token, + id_token_claims=id_token_claims(aud="...", iss="...", sub="..."), + ) + ) + + authority = "localhost" + expected_account = get_account_event( + "user", "object-id", "tenant-id", authority=authority, client_id="client-id", refresh_token="**" + ) + cache = populated_cache(expected_account) + + credential = SharedTokenCacheCredential( + allow_multitenant_authentication=True, authority=authority, transport=Mock(send=send), _cache=cache + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id="organizations") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +async def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + default_tenant = "organizations" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant_id = parsed.path.split("/")[1] + assert tenant_id == default_tenant + return mock_response( + json_payload=build_aad_response( + access_token=expected_token, + id_token_claims=id_token_claims(aud="...", iss="...", sub="..."), + ) + ) + + authority = "localhost" + expected_account = get_account_event( + "user", "object-id", "tenant-id", authority=authority, client_id="client-id", refresh_token="**" + ) + cache = populated_cache(expected_account) + + credential = SharedTokenCacheCredential(authority=authority, transport=Mock(send=send), _cache=cache) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=default_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="some tenant") + + # ...unless the compat switch is enabled + with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = await credential.get_token("scope", tenant_id="some tenant") + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index 0675d8da2547..ec06e41d14b4 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -6,6 +6,7 @@ import time from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzureAuthorityHosts, CredentialUnavailableError, VisualStudioCodeCredential from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity._constants import EnvironmentVariables @@ -265,3 +266,69 @@ def test_no_user_settings(): credential.get_token("scope") assert transport.send.call_count == 1 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential( + tenant_id=first_tenant, allow_multitenant_authentication=True, transport=mock.Mock(send=send) + ) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential(tenant_id=expected_tenant, transport=mock.Mock(send=send)) + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index 6afeb58f655a..d91332be0320 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzureAuthorityHosts, CredentialUnavailableError from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT @@ -124,9 +125,7 @@ async def mock_send(request, **kwargs): assert request.body["refresh_token"] == expected_refresh_token return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) - credential = get_credential( - tenant_id=tenant_id, transport=mock.Mock(send=mock_send), authority=authority - ) + credential = get_credential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send), authority=authority) with mock.patch(GET_REFRESH_TOKEN, return_value=expected_refresh_token): token = await credential.get_token("scope") assert token.token == access_token @@ -134,9 +133,7 @@ async def mock_send(request, **kwargs): # authority can be configured via environment variable with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = get_credential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send)) - with mock.patch( - GET_REFRESH_TOKEN, return_value=expected_refresh_token - ): + with mock.patch(GET_REFRESH_TOKEN, return_value=expected_refresh_token): await credential.get_token("scope") assert token.token == access_token @@ -191,7 +188,7 @@ async def test_no_obtain_token_if_cached(): token_by_refresh_token = mock.Mock(return_value=expected_token) mock_client = mock.Mock( get_cached_access_token=mock.Mock(return_value=expected_token), - obtain_token_by_refresh_token=wrap_in_future(token_by_refresh_token) + obtain_token_by_refresh_token=wrap_in_future(token_by_refresh_token), ) credential = get_credential(_client=mock_client) @@ -258,3 +255,71 @@ async def test_no_user_settings(): await credential.get_token("scope") assert transport.send.call_count == 1 + + +@pytest.mark.asyncio +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential( + tenant_id=first_tenant, allow_multitenant_authentication=True, transport=mock.Mock(send=send) + ) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +async def test_multitenant_authentication_not_allowed(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential(tenant_id=expected_tenant, transport=mock.Mock(send=send)) + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled"