Skip to content

Commit

Permalink
simplify TokenExchangeCredential
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Aug 5, 2021
1 parent 45ae1d4 commit a66e926
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, tenant_id, client_id, get_assertion, **kwargs):
"""
self._get_assertion = get_assertion
self._client = AadClient(tenant_id, client_id, **kwargs)
super(ClientAssertionCredential, self).__init__()
super(ClientAssertionCredential, self).__init__(**kwargs)

def __enter__(self):
self._client.__enter__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ def __init__(self, **kwargs):
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

self._credential = TokenExchangeCredential(**kwargs)
self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
**kwargs
)
else:
from .imds import ImdsCredential

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import time
from typing import TYPE_CHECKING

from .client_assertion import ClientAssertionCredential
from .._constants import EnvironmentVariables

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from azure.core.credentials import AccessToken


class TokenFileMixin(object):
def __init__(self):
# type: () -> None
def __init__(self, token_file_path, **kwargs): # pylint:disable=unused-argument
# type: (str, **Any) -> None
super(TokenFileMixin, self).__init__()
self._jwt = ""
self._last_read_time = 0
self._token_file_path = os.environ[EnvironmentVariables.TOKEN_FILE_PATH]
self._token_file_path = token_file_path

def get_service_account_token(self):
# type: () -> str
Expand All @@ -33,28 +30,9 @@ def get_service_account_token(self):
return self._jwt


class TokenExchangeCredential(TokenFileMixin):
def __init__(self, **kwargs):
# type: (**Any) -> None
super(TokenExchangeCredential, self).__init__()
self._credential = ClientAssertionCredential(
os.environ[EnvironmentVariables.AZURE_TENANT_ID],
os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
self.get_service_account_token,
**kwargs
class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
def __init__(self, token_file_path, **kwargs):
# type: (str, **Any) -> None
super(TokenExchangeCredential, self).__init__(
get_assertion=self.get_service_account_token, token_file_path=token_file_path, **kwargs
)

def __enter__(self):
self._credential.__enter__()
return self

def __exit__(self, *args):
self._credential.__exit__(*args)

def close(self):
# type: () -> None
self.__exit__()

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
return self._credential.get_token(*scopes, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, tenant_id: str, client_id: str, get_assertion: "Callable[[],
"""
self._get_assertion = get_assertion
self._client = AadClient(tenant_id, client_id, **kwargs)
super().__init__()
super().__init__(**kwargs)

async def __aenter__(self):
await self._client.__aenter__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def __init__(self, **kwargs: "Any") -> None:
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

self._credential = TokenExchangeCredential(**kwargs)
self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
**kwargs
)
else:
from .imds import ImdsCredential

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
from typing import TYPE_CHECKING

from .client_assertion import ClientAssertionCredential
from .._internal import AsyncContextManager
from ..._constants import EnvironmentVariables
from ..._credentials.token_exchange import TokenFileMixin

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from azure.core.credentials import AccessToken


class TokenExchangeCredential(AsyncContextManager, TokenFileMixin):
def __init__(self, **kwargs: "Any") -> None:
super().__init__()
self._credential = ClientAssertionCredential(
os.environ[EnvironmentVariables.AZURE_TENANT_ID],
os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
self.get_service_account_token,
**kwargs
)

async def __aenter__(self):
await self._credential.__aenter__()
return self

async def close(self) -> None:
await self._credential.close()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
return await self._credential.get_token(*scopes, **kwargs)
class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
def __init__(self, token_file_path: str, **kwargs: "Any") -> None:
super().__init__(get_assertion=self.get_service_account_token, token_file_path=token_file_path, **kwargs)

0 comments on commit a66e926

Please sign in to comment.