diff --git a/airflow/auth/managers/fab/security_manager/override.py b/airflow/auth/managers/fab/security_manager/override.py
index be1fffa9547..96b40a12afb 100644
--- a/airflow/auth/managers/fab/security_manager/override.py
+++ b/airflow/auth/managers/fab/security_manager/override.py
@@ -17,17 +17,15 @@
 # under the License.
 from __future__ import annotations
 
-import base64
-import json
 import logging
 import os
 import random
 import uuid
 import warnings
 from functools import cached_property
-from typing import TYPE_CHECKING, Container, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Container, Iterable, Sequence
 
-import re2
+import jwt
 from flask import flash, g, session
 from flask_appbuilder import const
 from flask_appbuilder.const import (
@@ -41,6 +39,7 @@
     LOGMSG_ERR_SEC_AUTH_LDAP_TLS,
     LOGMSG_WAR_SEC_LOGIN_FAILED,
     LOGMSG_WAR_SEC_NOLDAP_OBJ,
+    MICROSOFT_KEY_SET_URL,
 )
 from flask_appbuilder.models.sqla import Base
 from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -1390,9 +1389,8 @@ def auth_user_db(self, username, password):
             log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username)
             return None
 
-    def get_oauth_user_info(self, provider, resp):
-        """
-        Get the OAuth user information from different OAuth APIs.
+    def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]:
+        """Get the OAuth user information from different OAuth APIs.
 
         All providers have different ways to retrieve user info.
         """
@@ -1432,23 +1430,15 @@ def get_oauth_user_info(self, provider, resp):
                 "last_name": data.get("family_name", ""),
                 "email": data.get("email", ""),
             }
-        # for Azure AD Tenant. Azure OAuth response contains
-        # JWT token which has user info.
-        # JWT token needs to be base64 decoded.
-        # https://docs.microsoft.com/en-us/azure/active-directory/develop/
-        # active-directory-protocols-oauth-code
         if provider == "azure":
-            log.debug("Azure response received : %s", resp)
-            id_token = resp["id_token"]
-            log.debug(str(id_token))
-            me = self._azure_jwt_token_parse(id_token)
-            log.debug("Parse JWT token : %s", me)
+            me = self._decode_and_validate_azure_jwt(resp["id_token"])
+            log.debug("User info from Azure: %s", me)
+            # https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims
             return {
-                "name": me.get("name", ""),
-                "email": me["upn"],
+                # Keep using "upn" when present (as before this change); fall back to "email".
+                "email": me["upn"] if "upn" in me else me["email"],
                 "first_name": me.get("given_name", ""),
                 "last_name": me.get("family_name", ""),
-                "id": me["oid"],
                 "username": me["oid"],
                 "role_keys": me.get("roles", []),
             }
@@ -1535,52 +1525,32 @@ def ldap_extract(ldap_dict: dict[str, list[bytes]], field_name: str, fallback: s
     ---------------
     """
 
-    @staticmethod
-    def _azure_parse_jwt(token):
-        """
-        Parse Azure JWT token content.
+    def _get_microsoft_jwks(self) -> list[dict[str, Any]]:
+        """Fetch the JSON Web Key Set published at ``MICROSOFT_KEY_SET_URL``."""
+        import requests
 
-        :param token: the JWT token
+        # A timeout keeps an unreachable endpoint from blocking the login request indefinitely.
+        response = requests.get(MICROSOFT_KEY_SET_URL, timeout=10)
+        response.raise_for_status()
+        return response.json()
 
-        :meta private:
-        """
-        jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
-        matches = re2.search(jwt_token_parts, token)
-        if not matches or len(matches.groups()) < 3:
-            log.error("Unable to parse token.")
-            return {}
-        return {
-            "header": matches.group(1),
-            "Payload": matches.group(2),
-            "Sig": matches.group(3),
-        }
-
-    @staticmethod
-    def _azure_jwt_token_parse(token):
-        """
-        Parse and decode Azure JWT token.
-
-        :param token: the JWT token
-
-        :meta private:
-        """
-        jwt_split_token = FabAirflowSecurityManagerOverride._azure_parse_jwt(token)
-        if not jwt_split_token:
-            return
-
-        jwt_payload = jwt_split_token["Payload"]
-        # Prepare for base64 decoding
-        payload_b64_string = jwt_payload
-        payload_b64_string += "=" * (4 - (len(jwt_payload) % 4))
-        decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii"))
-
-        if not decoded_payload:
-            log.error("Payload of id_token could not be base64 url decoded.")
-            return
+    def _decode_and_validate_azure_jwt(self, id_token: str) -> dict[str, Any]:
+        """Decode the Azure ``id_token`` and return its claims.
+
+        The signature is verified against the key set from ``_get_microsoft_jwks`` and
+        ``claims.validate()`` is run only when the ``azure`` remote's ``client_kwargs`` set
+        ``verify_signature`` to a truthy value; otherwise the token is decoded WITHOUT verification.
+        """
+        verify_signature = self.oauth_remotes["azure"].client_kwargs.get("verify_signature", False)
+        if verify_signature:
+            from authlib.jose import JsonWebKey, jwt as authlib_jwt
 
-        jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8"))
+            keyset = JsonWebKey.import_key_set(self._get_microsoft_jwks())
+            claims = authlib_jwt.decode(id_token, keyset)
+            claims.validate()
+            return claims
 
-        return jwt_decoded_payload
+        return jwt.decode(id_token, options={"verify_signature": False})
 
     def _ldap_bind_indirect(self, ldap, con) -> None:
         """
diff --git a/airflow/www/fab_security/manager.py b/airflow/www/fab_security/manager.py
index b9d63faf565..b779ff4a12c 100644
--- a/airflow/www/fab_security/manager.py
+++ b/airflow/www/fab_security/manager.py
@@ -20,7 +20,7 @@
 
 import datetime
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable
 from uuid import uuid4
 
 import re2
@@ -276,7 +276,10 @@ def current_user(self):
         elif current_user_jwt:
             return current_user_jwt
 
-    def oauth_user_info_getter(self, f):
+    def oauth_user_info_getter(
+        self,
+        func: Callable[[BaseSecurityManager, str, dict[str, Any] | None], dict[str, Any]],
+    ):
         """
         Get OAuth user info; used by all providers.
 
@@ -290,17 +293,11 @@ def my_oauth_user_info(sm, provider, response=None):
                 if provider == 'github':
                     me = sm.oauth_remotes[provider].get('user')
                     return {'username': me.data.get('login')}
-                else:
-                    return {}
+                return {}
         """
 
-        def wraps(provider, response=None):
-            ret = f(self, provider, response=response)
-            # Checks if decorator is well behaved and returns a dict as supposed.
-            if not isinstance(ret, dict):
-                log.error("OAuth user info decorated function did not returned a dict, but: %s", type(ret))
-                return {}
-            return ret
+        def wraps(provider: str, response: dict[str, Any] | None = None) -> dict[str, Any]:
+            return func(self, provider, response)
 
         self.oauth_user_info = wraps
         return wraps
diff --git a/setup.cfg b/setup.cfg
index b4adb9267c5..ad7e00043bb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -66,11 +66,6 @@ setup_requires =
 #   DEPENDENCIES_EPOCH_NUMBER in the Dockerfile.ci
 #####################################################################################################
 install_requires =
-    # WTForms 3.1.0 released 10th of October 2023 introduced a breaking change in the way
-    # QuerySelectChoices fields are handled.
-    # See details in https://github.com/dpgaspar/Flask-AppBuilder/issues/2137
-    # We should remove this limitation when the issue is fixed
-    WTForms<3.1.0
     # Alembic is important to handle our migrations in predictable and performant way. It is developed
     # together with SQLAlchemy. Our experience with Alembic is that it very stable in minor version
     alembic>=1.6.3, <2.0
@@ -98,7 +93,7 @@ install_requires =
     # `airflow/www/fab_security` with their upstream counterparts. In particular, make sure any breaking changes,
     # for example any new methods, are accounted for.
     # NOTE! When you change the value here, you also have to update flask-appbuilder[oauth] in setup.py
-    flask-appbuilder==4.3.6
+    flask-appbuilder==4.3.9
     flask-caching>=1.5.0
     flask-login>=0.6.2
     flask-session>=0.4.0
diff --git a/setup.py b/setup.py
index 0b5a03b38ab..84ca00c0c91 100644
--- a/setup.py
+++ b/setup.py
@@ -328,7 +328,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve
 flask_appbuilder_oauth = [
     "authlib>=1.0.0",
     # The version here should be upgraded at the same time as flask-appbuilder in setup.cfg
-    "flask-appbuilder[oauth]==4.3.6",
+    "flask-appbuilder[oauth]==4.3.9",
 ]
 kerberos = [
     "pykerberos>=1.1.13",