Skip to content

Commit

Permalink
Upgrade to Flask Application Builder 4.3.9 (#35085)
Browse files Browse the repository at this point in the history
This PR brings all the necessary changes to upgrade to FAB 4.3.9 from
4.3.6.

It incorporates those changes:

* dpgaspar/Flask-AppBuilder#2112
* dpgaspar/Flask-AppBuilder#2121

It also removes the limitation of the WTForms after compatibility has
been implemented:

* dpgaspar/Flask-AppBuilder#2138

GitOrigin-RevId: 4198146f49b72d051d82fbd821c7105cf2f4a8bd
  • Loading branch information
potiuk authored and Cloud Composer Team committed Nov 8, 2024
1 parent 04fad69 commit 53f9cc5
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 80 deletions.
83 changes: 21 additions & 62 deletions airflow/auth/managers/fab/security_manager/override.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@
# under the License.
from __future__ import annotations

import base64
import json
import logging
import os
import random
import uuid
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Container, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Container, Iterable, Sequence

import re2
import jwt
from flask import flash, g, session
from flask_appbuilder import const
from flask_appbuilder.const import (
Expand All @@ -41,6 +39,7 @@
LOGMSG_ERR_SEC_AUTH_LDAP_TLS,
LOGMSG_WAR_SEC_LOGIN_FAILED,
LOGMSG_WAR_SEC_NOLDAP_OBJ,
MICROSOFT_KEY_SET_URL,
)
from flask_appbuilder.models.sqla import Base
from flask_appbuilder.models.sqla.interface import SQLAInterface
Expand Down Expand Up @@ -1390,9 +1389,8 @@ def auth_user_db(self, username, password):
log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username)
return None

def get_oauth_user_info(self, provider, resp):
"""
Get the OAuth user information from different OAuth APIs.
def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]:
"""There are different OAuth APIs with different ways to retrieve user info.
All providers have different ways to retrieve user info.
"""
Expand Down Expand Up @@ -1432,23 +1430,14 @@ def get_oauth_user_info(self, provider, resp):
"last_name": data.get("family_name", ""),
"email": data.get("email", ""),
}
# for Azure AD Tenant. Azure OAuth response contains
# JWT token which has user info.
# JWT token needs to be base64 decoded.
# https://docs.microsoft.com/en-us/azure/active-directory/develop/
# active-directory-protocols-oauth-code
if provider == "azure":
log.debug("Azure response received : %s", resp)
id_token = resp["id_token"]
log.debug(str(id_token))
me = self._azure_jwt_token_parse(id_token)
log.debug("Parse JWT token : %s", me)
me = self._decode_and_validate_azure_jwt(resp["id_token"])
log.debug("User info from Azure: %s", me)
# https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims
return {
"name": me.get("name", ""),
"email": me["upn"],
"email": me["email"],
"first_name": me.get("given_name", ""),
"last_name": me.get("family_name", ""),
"id": me["oid"],
"username": me["oid"],
"role_keys": me.get("roles", []),
}
Expand Down Expand Up @@ -1535,52 +1524,22 @@ def ldap_extract(ldap_dict: dict[str, list[bytes]], field_name: str, fallback: s
---------------
"""

@staticmethod
def _azure_parse_jwt(token):
"""
Parse Azure JWT token content.
def _get_microsoft_jwks(self) -> list[dict[str, Any]]:
import requests

:param token: the JWT token
return requests.get(MICROSOFT_KEY_SET_URL).json()

:meta private:
"""
jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
matches = re2.search(jwt_token_parts, token)
if not matches or len(matches.groups()) < 3:
log.error("Unable to parse token.")
return {}
return {
"header": matches.group(1),
"Payload": matches.group(2),
"Sig": matches.group(3),
}

@staticmethod
def _azure_jwt_token_parse(token):
"""
Parse and decode Azure JWT token.
:param token: the JWT token
:meta private:
"""
jwt_split_token = FabAirflowSecurityManagerOverride._azure_parse_jwt(token)
if not jwt_split_token:
return

jwt_payload = jwt_split_token["Payload"]
# Prepare for base64 decoding
payload_b64_string = jwt_payload
payload_b64_string += "=" * (4 - (len(jwt_payload) % 4))
decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii"))

if not decoded_payload:
log.error("Payload of id_token could not be base64 url decoded.")
return
def _decode_and_validate_azure_jwt(self, id_token: str) -> dict[str, str]:
verify_signature = self.oauth_remotes["azure"].client_kwargs.get("verify_signature", False)
if verify_signature:
from authlib.jose import JsonWebKey, jwt as authlib_jwt

jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8"))
keyset = JsonWebKey.import_key_set(self._get_microsoft_jwks())
claims = authlib_jwt.decode(id_token, keyset)
claims.validate()
return claims

return jwt_decoded_payload
return jwt.decode(id_token, options={"verify_signature": False})

def _ldap_bind_indirect(self, ldap, con) -> None:
"""
Expand Down
19 changes: 8 additions & 11 deletions airflow/www/fab_security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import datetime
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable
from uuid import uuid4

import re2
Expand Down Expand Up @@ -276,7 +276,10 @@ def current_user(self):
elif current_user_jwt:
return current_user_jwt

def oauth_user_info_getter(self, f):
def oauth_user_info_getter(
self,
func: Callable[[BaseSecurityManager, str, dict[str, Any] | None], dict[str, Any]],
):
"""
Get OAuth user info; used by all providers.
Expand All @@ -290,17 +293,11 @@ def my_oauth_user_info(sm, provider, response=None):
if provider == 'github':
me = sm.oauth_remotes[provider].get('user')
return {'username': me.data.get('login')}
else:
return {}
return {}
"""

def wraps(provider, response=None):
ret = f(self, provider, response=response)
# Checks if decorator is well behaved and returns a dict as supposed.
if not isinstance(ret, dict):
log.error("OAuth user info decorated function did not returned a dict, but: %s", type(ret))
return {}
return ret
def wraps(provider: str, response: dict[str, Any] | None = None) -> dict[str, Any]:
return func(self, provider, response)

self.oauth_user_info = wraps
return wraps
Expand Down
7 changes: 1 addition & 6 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,6 @@ setup_requires =
# DEPENDENCIES_EPOCH_NUMBER in the Dockerfile.ci
#####################################################################################################
install_requires =
# WTForms 3.1.0 released 10th of October 2023 introduced a breaking change in the way
# QuerySelectChoices fields are handled.
# See details in https://github.com/dpgaspar/Flask-AppBuilder/issues/2137
# We should remove this limitation when the issue is fixed
WTForms<3.1.0
# Alembic is important to handle our migrations in predictable and performant way. It is developed
# together with SQLAlchemy. Our experience with Alembic is that it very stable in minor version
alembic>=1.6.3, <2.0
Expand Down Expand Up @@ -98,7 +93,7 @@ install_requires =
# `airflow/www/fab_security` with their upstream counterparts. In particular, make sure any breaking changes,
# for example any new methods, are accounted for.
# NOTE! When you change the value here, you also have to update flask-appbuilder[oauth] in setup.py
flask-appbuilder==4.3.6
flask-appbuilder==4.3.9
flask-caching>=1.5.0
flask-login>=0.6.2
flask-session>=0.4.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve
flask_appbuilder_oauth = [
"authlib>=1.0.0",
# The version here should be upgraded at the same time as flask-appbuilder in setup.cfg
"flask-appbuilder[oauth]==4.3.6",
"flask-appbuilder[oauth]==4.3.9",
]
kerberos = [
"pykerberos>=1.1.13",
Expand Down

0 comments on commit 53f9cc5

Please sign in to comment.