Skip to content

Commit

Permalink
Refactor SharedTokenCacheCredential (#19914)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jul 28, 2021
1 parent 4192b56 commit ef54724
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 105 deletions.
137 changes: 35 additions & 102 deletions sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import time

from msal.application import PublicClientApplication

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from typing import TYPE_CHECKING

from .silent import SilentAuthenticationCredential
from .. import CredentialUnavailableError
from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID
from .._internal import AadClient, resolve_tenant, validate_tenant_id
from .._internal.decorators import log_get_token, wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal import AadClient
from .._internal.decorators import log_get_token
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, Optional
from .. import AuthenticationRecord
from typing import Any, Optional
from azure.core.credentials import TokenCredential
from .._internal import AadClientBase


class SharedTokenCacheCredential(SharedTokenCacheBase):
class SharedTokenCacheCredential(object):
"""Authenticates using tokens in the local cache shared between Microsoft applications.
:param str username:
Username (typically an email address) of the user to authenticate as. This is used when the local cache
contains tokens for multiple identities.
:param str username: Username (typically an email address) of the user to authenticate as. This is used when the
local cache contains tokens for multiple identities.
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
Expand All @@ -55,21 +43,13 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
def __init__(self, username=None, **kwargs):
# type: (Optional[str], **Any) -> None

self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
self._cache = kwargs.pop("_cache", None)
self._client_applications = {} # type: Dict[str, PublicClientApplication]
self._msal_client = MsalClient(**kwargs)
self._initialized = False
if "authentication_record" in kwargs:
self._credential = SilentAuthenticationCredential(**kwargs) # type: TokenCredential
else:
super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs)
self._credential = _SharedTokenCacheCredential(username=username, **kwargs)

@log_get_token("SharedTokenCacheCredential")
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def get_token(self, *scopes, **kwargs):
# type (*str, **Any) -> AccessToken
"""Get an access token for `scopes` from the shared cache.
Expand All @@ -78,14 +58,34 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
claims challenge following an authorization failure
:rtype: :class:`azure.core.credentials.AccessToken`
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
attribute gives a reason.
"""
return self._credential.get_token(*scopes, **kwargs)

@staticmethod
def supported():
# type: () -> bool
"""Whether the shared token cache is supported on the current platform.
:rtype: bool
"""
return SharedTokenCacheBase.supported()


class _SharedTokenCacheCredential(SharedTokenCacheBase):
"""The original SharedTokenCacheCredential, which doesn't use msal.ClientApplication"""

def get_token(self, *scopes, **kwargs):
# type (*str, **Any) -> AccessToken
if not scopes:
raise ValueError("'get_token' requires at least one scope")

Expand All @@ -95,9 +95,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes, **kwargs)

account = self._get_account(self._username, self._tenant_id)

token = self._get_cached_access_token(scopes, account)
Expand All @@ -114,67 +111,3 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
return AadClient(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, **kwargs)

def _initialize(self):
if self._initialized:
return

if not self._auth_record:
super(SharedTokenCacheCredential, self)._initialize()
return

self._load_cache()
self._initialized = True

def _get_client_application(self, **kwargs):
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
if tenant_id not in self._client_applications:
# CP1 = can handle claims challenges (CAE)
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
http_client=self._msal_client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

# this won't be None when this method is called by get_token but we check anyway to satisfy mypy
if self._auth_record is None:
raise CredentialUnavailableError("Initialization failed")

result = None

client_application = self._get_client_application(**kwargs)
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

for account in accounts_for_user:
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = client_application.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
if result:
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
message = "Token acquisition failed"
details = result.get("error_description") or result.get("error")
if details:
message += ": {}".format(details)
raise ClientAuthenticationError(message=message)

# cache doesn't contain a matching refresh (or access) token
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
116 changes: 116 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_credentials/silent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import time
from typing import TYPE_CHECKING

from msal import PublicClientApplication

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._internal import resolve_tenant, validate_tenant_id
from .._internal.decorators import wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict
from .. import AuthenticationRecord


class SilentAuthenticationCredential(object):
"""Internal class for authenticating from the default shared cache given an AuthenticationRecord"""

def __init__(self, authentication_record, **kwargs):
# type: (AuthenticationRecord, **Any) -> None
self._auth_record = authentication_record

# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
self._cache = kwargs.pop("_cache", None)
self._client_applications = {} # type: Dict[str, PublicClientApplication]
self._client = MsalClient(**kwargs)
self._initialized = False

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type (*str, **Any) -> AccessToken
if not scopes:
raise ValueError('"get_token" requires at least one scope')

if not self._initialized:
self._initialize()

if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
try:
# This credential accepts the user's default cache regardless of whether it's encrypted. It doesn't
# create a new cache. If the default cache exists, the user must have created it earlier. If it's
# unencrypted, the user must have allowed that.
self._cache = _load_persistent_cache(TokenCachePersistenceOptions(allow_unencrypted_storage=True))
except Exception: # pylint:disable=broad-except
pass

self._initialized = True

def _get_client_application(self, **kwargs):
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
if tenant_id not in self._client_applications:
# CP1 = can handle claims challenges (CAE)
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
http_client=self._client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL."""

result = None

client_application = self._get_client_application(**kwargs)
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

for account in accounts_for_user:
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = client_application.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
if result:
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
message = "Token acquisition failed"
details = result.get("error_description") or result.get("error")
if details:
message += ": {}".format(details)
raise ClientAuthenticationError(message=message)

# cache doesn't contain a matching refresh (or access) token
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def send(request, **_):
transport = Mock(send=send)
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())

with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
with pytest.raises(ClientAuthenticationError): # (cache is empty)
credential.get_token("scope")

Expand All @@ -761,7 +761,7 @@ def send(request, **_):
assert kwargs["client_capabilities"] == ["CP1"]

credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
with pytest.raises(ClientAuthenticationError): # (cache is empty)
credential.get_token("scope")
Expand All @@ -786,7 +786,7 @@ def test_claims_challenge():

transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication", lambda *_, **__: msal_app):
with patch("azure.identity._credentials.silent.PublicClientApplication", lambda *_, **__: msal_app):
credential.get_token("scope", claims=expected_claims)

assert msal_app.acquire_token_silent_with_error.call_count == 1
Expand Down

0 comments on commit ef54724

Please sign in to comment.