Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request management token via Azure CLI only for Service Principals and not human users #408

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,23 +314,42 @@ def __init__(self, resource: str, subscription: str = ""):
access_token_field='accessToken',
expiry_field='expiresOn')

def is_human_user(self) -> bool:
    """Report whether the current Azure CLI token belongs to a human user.

    Azure CLI can be logged in either as a human user (`az login`) or as a service principal. A service
    principal session comes from OIDC on GitHub or from a password login:

        ~ $ az login --service-principal --user $clientID --password $clientSecret --tenant $tenantID

    Tokens issued to human users carry additional claims:
    - 'amr' - how the subject of the token was authenticated
    - 'name', 'family_name', 'given_name' - human-readable values that identify the subject of the token
    - 'scp' with the `user_impersonation` value - the set of scopes exposed by the application for which
      the client application has requested (and received) consent
    - 'unique_name' - a human-readable value that identifies the subject of the token. It is not guaranteed
      to be unique within a tenant and should be used only for display purposes.
    - 'upn' - the username of the user.

    Only the 'upn' claim is checked here: it holds the username of a user and is not issued for a
    service principal.

    :return: True when the token's claims contain 'upn', False otherwise.
    """
    claims = self.token().jwt_claims()
    return 'upn' in claims
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@staticmethod
def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
    """Build a token source for `resource`, preferring one scoped to the configured subscription.

    A token is requested eagerly from every candidate source, so a source that cannot produce a token
    is detected here rather than on first use by the caller.

    :param cfg: configuration from which the subscription is resolved via `get_subscription`.
    :param resource: the resource to request the token for.
    :return: a token source that has already produced a token successfully.
    :raises: whatever `token()` raises for the resource-only source; only the `OSError` from the
        subscription-scoped attempt is handled here.
    """
    subscription = AzureCliTokenSource.get_subscription(cfg)
    if subscription != "":
        token_source = AzureCliTokenSource(resource, subscription)
        try:
            # This will fail if the user has access to the workspace, but not to the subscription
            # itself. In such case, we fall back to not using the subscription.
            token_source.token()
            return token_source
        except OSError:
            logger.warning("Failed to get token for subscription. Using resource only token.")

    # Either no subscription is configured or the subscription-scoped request failed.
    token_source = AzureCliTokenSource(resource)
    token_source.token()
    return token_source

@staticmethod
def get_subscription(cfg: 'Config') -> str:
Expand All @@ -355,12 +374,13 @@ def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
return None
try:
mgmt_token_source = AzureCliTokenSource.for_resource(cfg,
cfg.arm_environment.service_management_endpoint)
except Exception as e:
logger.debug(f'Not including service management token in headers', exc_info=e)
mgmt_token_source = None
if not token_source.is_human_user():
try:
management_endpoint = cfg.arm_environment.service_management_endpoint
mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
except Exception as e:
logger.debug(f'Not including service management token in headers', exc_info=e)
mgmt_token_source = None

_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
logger.info("Using Azure CLI authentication with AAD tokens")
Expand Down
37 changes: 37 additions & 0 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,43 @@ def from_dict(raw: dict) -> 'Token':
expiry=datetime.fromisoformat(raw['expiry']),
refresh_token=raw.get('refresh_token'))

def jwt_claims(self) -> Dict[str, str]:
    """Get claims from the access token or return an empty dictionary if it is not a JWT token.

    All refreshable tokens we're dealing with are JSON Web Tokens (JWT).

    The common claims are:
    - 'aud' represents the intended recipient of the token. In case of Azure, this is an app's Application ID
      assigned within the Azure portal.
    - 'iss' serves to identify the security token service (STS) responsible for creating and delivering the
      token. In case of Azure, it includes the Azure AD tenant where user authentication occurred.
    - 'appid' stands for the application ID of the client utilizing this token. This application can operate
      either autonomously or on behalf of a user. The application ID commonly represents an application
      object but may also denote a service principal object in case of Azure.
    - 'idp' is used to document the identity provider that authenticated the subject of the token.
    - 'oid' is the unchanging identifier for an entity within the identity system.
    - 'sub' identifies the primary entity for the token, such as the user of an app. This value is specific to
      a particular application ID. If a single user logs into two different apps using distinct client IDs,
      these apps will receive different values for the subject claim.
    - 'tid' In case of Azure, this value represents Azure Tenant ID.

    The signature is NOT verified: the payload is only decoded, so the result must not be used for
    authorization decisions.

    See https://datatracker.ietf.org/doc/html/rfc7519 for specification.
    See https://jwt.ms for debugger.

    :return: the decoded claims, or an empty dictionary when the access token is not a three-part JWT,
        its payload cannot be decoded, or the payload is not a JSON object.
    """
    try:
        jwt_split = self.access_token.split(".")
        if len(jwt_split) != 3:
            logger.debug(f'Tried to decode access token as JWT, but failed: {len(jwt_split)} components')
            return {}
        payload = jwt_split[1]
        # JWT segments are base64url-encoded with the trailing '=' padding stripped (RFC 7515).
        # The standard alphabet decoder silently drops '-' and '_', which corrupts the payload,
        # so the URL-safe decoder is required. Restore exactly the padding that was removed.
        payload_with_padding = payload + "=" * (-len(payload) % 4)
        payload_bytes = base64.urlsafe_b64decode(payload_with_padding)
        payload_json = payload_bytes.decode("utf8")
        claims = json.loads(payload_json)
    except ValueError as err:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all ValueError subclasses.
        logger.debug(f'Tried to decode access token as JWT, but failed: {err}')
        return {}
    if not isinstance(claims, dict):
        # Valid JSON that is not an object (e.g. a list or a number) carries no claims.
        logger.debug(f'Tried to decode access token as JWT, but failed: payload is {type(claims).__name__}')
        return {}
    return claims


class TokenSource:

Expand Down
Loading