Skip to content

Commit

Permalink
Enhance Azure OAuth handler (#1474)
Browse files Browse the repository at this point in the history
* Enhance Azure OAuth handler

* Refactor OAuth callbacks and logging

* Fix flake
  • Loading branch information
philippjfr authored Jul 10, 2020
1 parent c123549 commit 8a093ca
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 71 deletions.
144 changes: 73 additions & 71 deletions panel/auth.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import codecs
import json
import logging
import os
import pkg_resources
import re

try:
from urllib import urlencode
except ImportError:
# python 3
from urllib.parse import urlencode
from urllib.parse import urlencode

import tornado

from bokeh.server.auth_provider import AuthProvider
from tornado.auth import OAuth2Mixin
from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat

from .config import config
from .io import state
from .util import base64url_encode, base64url_decode

log = logging.getLogger(__file__)



def decode_response_body(response):
Expand Down Expand Up @@ -49,31 +50,22 @@ class OAuthLoginHandler(tornado.web.RequestHandler):
x_site_token = 'application'

async def get_authenticated_user(self, redirect_uri, client_id, state,
client_secret=None, code=None,
success_callback=None,
error_callback=None):
client_secret=None, code=None):
""" Fetches the authenticated user
:param redirect_uri: the redirect URI
:param client_id: the client ID
:param state: the unguessable random string to protect against
cross-site request forgery attacks
:param client_secret: the client secret
:param code: the response code from the server
:param success_callback: the success callback used when fetching
the access token succeeds
:param error_callback: the callback used when fetching the access
token fails
"""
if code:
await self._fetch_access_token(
return await self._fetch_access_token(
code,
success_callback,
error_callback,
redirect_uri,
client_id,
client_secret
)
return

params = {
'redirect_uri': redirect_uri,
Expand All @@ -87,21 +79,17 @@ async def get_authenticated_user(self, redirect_uri, client_id, state,
params['scope'] = self._SCOPE
if 'scope' in config.oauth_extra_params:
params['scope'] = config.oauth_extra_params['scope']
log.info("%s making authorize request" % type(self).__name__)
self.authorize_redirect(**params)

async def _fetch_access_token(self, code, success_callback, error_callback,
redirect_uri, client_id, client_secret):
async def _fetch_access_token(self, code, redirect_uri, client_id, client_secret):
"""
Fetches the access token.
Arguments
----------
code:
The response code from the server
success_callback:
The callback used when fetching the access token succeeds
error_callback:
The callback used when fetching the access token fails
redirect_uri:
The redirect URI
client_id:
Expand All @@ -112,10 +100,10 @@ async def _fetch_access_token(self, code, success_callback, error_callback,
The unguessable random string to protect against cross-site
request forgery attacks
"""
if not (client_secret and success_callback and error_callback):
raise ValueError(
'The client secret or any callbacks are undefined.'
)
if not client_secret:
raise ValueError('The client secret is undefined.')

log.info("%s making access token request." % type(self).__name__)

params = {
'code': code,
Expand Down Expand Up @@ -147,7 +135,7 @@ async def _fetch_access_token(self, code, success_callback, error_callback,
}
if response.error:
data['error'] = response.error
error_callback(**data)
return self._on_error(**data)

user_response = await http.fetch(
'{}{}'.format(
Expand All @@ -160,9 +148,12 @@ async def _fetch_access_token(self, code, success_callback, error_callback,

if not user:
return
success_callback(user, body['access_token'])

log.info("%s received user information." % type(self).__name__)
return self._on_auth(user, body['access_token'])

async def get(self):
log.info("%s received login request" % type(self).__name__)
if config.oauth_redirect_uri:
redirect_uri = config.oauth_redirect_uri
else:
Expand All @@ -177,7 +168,7 @@ async def get(self):
}
# Some OAuth2 backends do not correctly return code
next_code = self.get_argument('next', None)
if 'code=' in next_code:
if next_code and 'code=' in next_code:
url_params = next_code[next_code.index('code='):].replace('code=', '').split('&')
code = url_params[0]
state = [p.replace('state=', '') for p in url_params if p.startswith('state')]
Expand All @@ -191,28 +182,31 @@ async def get(self):
# retrieved from the query string.
params.update({
'client_secret': config.oauth_secret,
'success_callback': self._on_auth,
'error_callback': self._on_error,
'code': code,
'state': state
})
user = await self.get_authenticated_user(**params)
if user is None:
raise tornado.web.HTTPError(403)
log.info("%s authorized user, redirecting to app." % type(self).__name__)
self.redirect('/')
else:
# Redirect for user authentication
await self.get_authenticated_user(**params)
return
# Redirect for user authentication
await self.get_authenticated_user(**params)

def _on_auth(self, user_info, access_token):
user_key = config.oauth_jwt_user or self._USER_KEY
self.set_secure_cookie('user', user_info[user_key])
user = user_info[user_key]
self.set_secure_cookie('user', user)
id_token = base64url_encode(json.dumps(user_info))
if state.encryption:
access_token = state.encryption.encrypt(access_token.encode('utf-8'))
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
self.set_secure_cookie('access_token', access_token)
self.set_secure_cookie('id_token', id_token)
self.redirect('/')
return user

def _on_error(self, user):
def _on_error(self, **kwargs):
self.clear_all_cookies()
name = type(self).__name__.replace('LoginHandler', '')
raise tornado.web.HTTPError(500, '%s authentication failed' % name)
Expand Down Expand Up @@ -282,19 +276,14 @@ def _OAUTH_USER_URL(self):
url = config.oauth_extra_params.get('url', 'gitlab.com')
return self._OAUTH_USER_URL_.format(url)

async def _fetch_access_token(self, code, success_callback, error_callback,
redirect_uri, client_id, client_secret):
async def _fetch_access_token(self, code, redirect_uri, client_id, client_secret):
"""
Fetches the access token.
Arguments
----------
code:
The response code from the server
success_callback:
The callback used when fetching the access token succeeds
error_callback:
The callback used when fetching the access token fails
redirect_uri:
The redirect URI
client_id:
Expand All @@ -305,10 +294,10 @@ async def _fetch_access_token(self, code, success_callback, error_callback,
The unguessable random string to protect against cross-site
request forgery attacks
"""
if not (client_secret and success_callback and error_callback):
raise ValueError(
'The client secret or any callbacks are undefined.'
)
if not client_secret:
raise ValueError('The client secret is undefined.')

log.info("%s making access token request." % type(self).__name__)

http = self.get_auth_http_client()

Expand Down Expand Up @@ -342,11 +331,14 @@ async def _fetch_access_token(self, code, success_callback, error_callback,
}
if response.error:
data['error'] = response.error
error_callback(**data)
return self._on_error(**data)

log.info("%s granted access_token." % type(self).__name__)

headers = dict(self._API_BASE_HEADERS, **{
"Authorization": "Bearer {}".format(body['access_token']),
})

user_response = await http.fetch(
self._OAUTH_USER_URL,
method="GET",
Expand All @@ -357,30 +349,32 @@ async def _fetch_access_token(self, code, success_callback, error_callback,

if not user:
return
success_callback(user, body['access_token'])

log.info("%s received user information." % type(self).__name__)

return self._on_auth(user, body['access_token'])



class OAuthIDTokenLoginHandler(OAuthLoginHandler):

_API_BASE_HEADERS = {
'Content-Type':
'application/x-www-form-urlencoded; charset=UTF-8'
}

_EXTRA_AUTHORIZE_PARAMS = {
'grant_type': 'authorization_code'
}

async def _fetch_access_token(
self, code, success_callback, error_callback, redirect_uri,
client_id, client_secret):
async def _fetch_access_token(self, code, redirect_uri, client_id, client_secret):
"""
Fetches the access token.
Arguments
----------
code:
The response code from the server
success_callback:
The callback used when fetching the access token succeeds
error_callback:
The callback used when fetching the access token fails
redirect_uri:
The redirect URI
client_id:
Expand All @@ -391,10 +385,10 @@ async def _fetch_access_token(
The unguessable random string to protect against cross-site
request forgery attacks
"""
if not (client_secret and success_callback and error_callback):
raise ValueError(
'The client secret or any callbacks are undefined.'
)
if not client_secret:
raise ValueError('The client secret are undefined.')

log.info("%s making access token request." % type(self).__name__)

http = self.get_auth_http_client()

Expand All @@ -406,14 +400,18 @@ async def _fetch_access_token(
**self._EXTRA_AUTHORIZE_PARAMS
}

data = urlencode(
params, doseq=True, encoding='utf-8', safe='=')

# Request the access token.
response = await http.fetch(
req = HTTPRequest(
self._OAUTH_ACCESS_TOKEN_URL,
method='POST',
body=urlencode(params),
headers=self._API_BASE_HEADERS
method="POST",
headers=self._API_BASE_HEADERS,
body=data
)

response = await http.fetch(req)
decoded_body = decode_response_body(response)

if 'access_token' not in decoded_body:
Expand All @@ -424,25 +422,27 @@ async def _fetch_access_token(

if response.error:
data['error'] = response.error
error_callback(**data)
return
return self._on_error(**data)

log.info("%s granted access_token." % type(self).__name__)

access_token = decoded_body['access_token']
id_token = decoded_body['id_token']
success_callback(id_token, access_token)
return self._on_auth(id_token, access_token)

def _on_auth(self, id_token, access_token):
signing_input, _ = id_token.encode('utf-8').rsplit(b".", 1)
_, payload_segment = signing_input.split(b".", 1)
decoded = json.loads(base64url_decode(payload_segment).decode('utf-8'))
user_key = config.oauth_jwt_user or self._USER_KEY
self.set_secure_cookie('user', decoded[user_key])
user = decoded[user_key]
self.set_secure_cookie('user', user)
if state.encryption:
access_token = state.encryption.encrypt(access_token.encode('utf-8'))
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
self.set_secure_cookie('access_token', access_token)
self.set_secure_cookie('id_token', id_token)
self.redirect('/')
return user


class AzureAdLoginHandler(OAuthIDTokenLoginHandler, OAuth2Mixin):
Expand All @@ -460,11 +460,13 @@ class AzureAdLoginHandler(OAuthIDTokenLoginHandler, OAuth2Mixin):

@property
def _OAUTH_ACCESS_TOKEN_URL(self):
return self._OAUTH_ACCESS_TOKEN_URL_.format(**config.oauth_extra_params)
tenant = os.environ.get('AAD_TENANT_ID', config.oauth_extra_params.get('tenant', 'common'))
return self._OAUTH_ACCESS_TOKEN_URL_.format(tenant=tenant)

@property
def _OAUTH_AUTHORIZE_URL(self):
return self._OAUTH_AUTHORIZE_URL_.format(**config.oauth_extra_params)
tenant = os.environ.get('AAD_TENANT_ID', config.oauth_extra_params.get('tenant', 'common'))
return self._OAUTH_AUTHORIZE_URL_.format(tenant=tenant)

@property
def _OAUTH_USER_URL(self):
Expand Down
18 changes: 18 additions & 0 deletions panel/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ class _config(param.Parameterized):
_embed_save_path = param.String(default='./', doc="""
Where to save json files for embedded state.""")

_log_level = param.Selector(
default=None, objects=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
doc="Log level of Panel loggers")

_oauth_provider = param.ObjectSelector(
default=None, allow_None=True, objects=[], doc="""
Select between a list of authentification providers.""")
Expand Down Expand Up @@ -258,6 +262,20 @@ def inline(self, value):
validate_config(self, '_inline', value)
self._inline_ = value

@property
def log_level(self):
if self._log_level_ is not None:
return self._log_level_
elif 'PANEL_LOG_LEVEL' in os.environ:
return os.environ['PANEL_LOG_LEVEL'].upper()
else:
return self._log_level

@log_level.setter
def log_level(self, value):
validate_config(self, '_log_level', value)
self._log_level_ = value

@property
def oauth_provider(self):
if self._oauth_provider_ is not None:
Expand Down
Loading

0 comments on commit 8a093ca

Please sign in to comment.