Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SharedTokenCacheCredential #19914

Merged
merged 2 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import time

from msal.application import PublicClientApplication

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from typing import TYPE_CHECKING

from .silent import SilentAuthenticationCredential
from .. import CredentialUnavailableError
from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID
from .._internal import AadClient, resolve_tenant, validate_tenant_id
from .._internal.decorators import log_get_token, wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal import AadClient
from .._internal.decorators import log_get_token
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, Optional
from .. import AuthenticationRecord
from typing import Any, Optional
from azure.core.credentials import TokenCredential
from .._internal import AadClientBase


class SharedTokenCacheCredential(SharedTokenCacheBase):
class SharedTokenCacheCredential(object):
"""Authenticates using tokens in the local cache shared between Microsoft applications.

:param str username:
Username (typically an email address) of the user to authenticate as. This is used when the local cache
contains tokens for multiple identities.
:param str username: Username (typically an email address) of the user to authenticate as. This is used when the
local cache contains tokens for multiple identities.

:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
Expand All @@ -55,21 +43,13 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
def __init__(self, username=None, **kwargs):
# type: (Optional[str], **Any) -> None

self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
self._cache = kwargs.pop("_cache", None)
self._client_applications = {} # type: Dict[str, PublicClientApplication]
self._msal_client = MsalClient(**kwargs)
self._initialized = False
if "authentication_record" in kwargs:
self._credential = SilentAuthenticationCredential(**kwargs) # type: TokenCredential
else:
super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs)
self._credential = _SharedTokenCacheCredential(username=username, **kwargs)

@log_get_token("SharedTokenCacheCredential")
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def get_token(self, *scopes, **kwargs):
# type (*str, **Any) -> AccessToken
"""Get an access token for `scopes` from the shared cache.

Expand All @@ -78,14 +58,34 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.

:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
claims challenge following an authorization failure

:rtype: :class:`azure.core.credentials.AccessToken`

:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
attribute gives a reason.
"""
return self._credential.get_token(*scopes, **kwargs)

@staticmethod
def supported():
# type: () -> bool
"""Whether the shared token cache is supported on the current platform.

:rtype: bool
"""
return SharedTokenCacheBase.supported()


class _SharedTokenCacheCredential(SharedTokenCacheBase):
"""The original SharedTokenCacheCredential, which doesn't use msal.ClientApplication"""

def get_token(self, *scopes, **kwargs):
# type (*str, **Any) -> AccessToken
if not scopes:
raise ValueError("'get_token' requires at least one scope")

Expand All @@ -95,9 +95,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes, **kwargs)

account = self._get_account(self._username, self._tenant_id)

token = self._get_cached_access_token(scopes, account)
Expand All @@ -114,67 +111,3 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
return AadClient(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, **kwargs)

def _initialize(self):
if self._initialized:
return

if not self._auth_record:
super(SharedTokenCacheCredential, self)._initialize()
return

self._load_cache()
self._initialized = True

def _get_client_application(self, **kwargs):
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
if tenant_id not in self._client_applications:
# CP1 = can handle claims challenges (CAE)
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
http_client=self._msal_client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

# this won't be None when this method is called by get_token but we check anyway to satisfy mypy
if self._auth_record is None:
raise CredentialUnavailableError("Initialization failed")

result = None

client_application = self._get_client_application(**kwargs)
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

for account in accounts_for_user:
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = client_application.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
if result:
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
message = "Token acquisition failed"
details = result.get("error_description") or result.get("error")
if details:
message += ": {}".format(details)
raise ClientAuthenticationError(message=message)

# cache doesn't contain a matching refresh (or access) token
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
116 changes: 116 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_credentials/silent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import time
from typing import TYPE_CHECKING

from msal import PublicClientApplication

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._internal import resolve_tenant, validate_tenant_id
from .._internal.decorators import wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict
from .. import AuthenticationRecord


class SilentAuthenticationCredential(object):
"""Internal class for authenticating from the default shared cache given an AuthenticationRecord"""

def __init__(self, authentication_record, **kwargs):
# type: (AuthenticationRecord, **Any) -> None
self._auth_record = authentication_record

# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
self._cache = kwargs.pop("_cache", None)
self._client_applications = {} # type: Dict[str, PublicClientApplication]
self._client = MsalClient(**kwargs)
self._initialized = False

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type (*str, **Any) -> AccessToken
if not scopes:
raise ValueError('"get_token" requires at least one scope')

if not self._initialized:
self._initialize()

if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
try:
# This credential accepts the user's default cache regardless of whether it's encrypted. It doesn't
# create a new cache. If the default cache exists, the user must have created it earlier. If it's
# unencrypted, the user must have allowed that.
self._cache = _load_persistent_cache(TokenCachePersistenceOptions(allow_unencrypted_storage=True))
except Exception: # pylint:disable=broad-except
pass

self._initialized = True

def _get_client_application(self, **kwargs):
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
if tenant_id not in self._client_applications:
# CP1 = can handle claims challenges (CAE)
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
http_client=self._client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL."""

result = None

client_application = self._get_client_application(**kwargs)
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

for account in accounts_for_user:
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = client_application.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
if result:
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
message = "Token acquisition failed"
details = result.get("error_description") or result.get("error")
if details:
message += ": {}".format(details)
raise ClientAuthenticationError(message=message)

# cache doesn't contain a matching refresh (or access) token
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def send(request, **_):
transport = Mock(send=send)
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())

with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
with pytest.raises(ClientAuthenticationError): # (cache is empty)
credential.get_token("scope")

Expand All @@ -761,7 +761,7 @@ def send(request, **_):
assert kwargs["client_capabilities"] == ["CP1"]

credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
with pytest.raises(ClientAuthenticationError): # (cache is empty)
credential.get_token("scope")
Expand All @@ -786,7 +786,7 @@ def test_claims_challenge():

transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication", lambda *_, **__: msal_app):
with patch("azure.identity._credentials.silent.PublicClientApplication", lambda *_, **__: msal_app):
credential.get_token("scope", claims=expected_claims)

assert msal_app.acquire_token_silent_with_error.call_count == 1
Expand Down