Skip to content

Commit

Permalink
Enhance Azure OAuth handler
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jul 10, 2020
1 parent c123549 commit 6fc6081
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions panel/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import codecs
import json
import logging
import os
import pkg_resources
import re

Expand All @@ -13,12 +15,15 @@

from bokeh.server.auth_provider import AuthProvider
from tornado.auth import OAuth2Mixin
from tornado.httpclient import HTTPRequest
from tornado.httputil import url_concat

from .config import config
from .io import state
from .util import base64url_encode, base64url_decode

log = logging.getLogger(__file__)



def decode_response_body(response):
Expand Down Expand Up @@ -163,6 +168,7 @@ async def _fetch_access_token(self, code, success_callback, error_callback,
success_callback(user, body['access_token'])

async def get(self):
log.info("%s received login request" % type(self).__name__)
if config.oauth_redirect_uri:
redirect_uri = config.oauth_redirect_uri
else:
Expand All @@ -177,7 +183,7 @@ async def get(self):
}
# Some OAuth2 backends do not correctly return code
next_code = self.get_argument('next', None)
if 'code=' in next_code:
if next_code and 'code=' in next_code:
url_params = next_code[next_code.index('code='):].replace('code=', '').split('&')
code = url_params[0]
state = [p.replace('state=', '') for p in url_params if p.startswith('state')]
Expand Down Expand Up @@ -363,6 +369,11 @@ async def _fetch_access_token(self, code, success_callback, error_callback,

class OAuthIDTokenLoginHandler(OAuthLoginHandler):

_API_BASE_HEADERS = {
'Content-Type':
'application/x-www-form-urlencoded; charset=UTF-8'
}

_EXTRA_AUTHORIZE_PARAMS = {
'grant_type': 'authorization_code'
}
Expand Down Expand Up @@ -406,15 +417,24 @@ async def _fetch_access_token(
**self._EXTRA_AUTHORIZE_PARAMS
}

# Request the access token.
response = await http.fetch(
self._OAUTH_ACCESS_TOKEN_URL,
method='POST',
body=urlencode(params),
headers=self._API_BASE_HEADERS
)
try:
data = urllib.parse.urlencode(
params, doseq=True, encoding='utf-8', safe='=')

# Request the access token.
req = HTTPRequest(
self._OAUTH_ACCESS_TOKEN_URL,
method="POST",
headers=self._API_BASE_HEADERS,
body=data
)

decoded_body = decode_response_body(response)
response = await http.fetch(req)
decoded_body = decode_response_body(response)
except Exception:
import requests
response = requests.post(self._OAUTH_ACCESS_TOKEN_URL, data=params)
decoded_body = response.json()

if 'access_token' not in decoded_body:
data = {
Expand Down Expand Up @@ -460,11 +480,13 @@ class AzureAdLoginHandler(OAuthIDTokenLoginHandler, OAuth2Mixin):

@property
def _OAUTH_ACCESS_TOKEN_URL(self):
return self._OAUTH_ACCESS_TOKEN_URL_.format(**config.oauth_extra_params)
tenant = os.environ.get('AAD_TENANT_ID', config.oauth_extra_params.get('tenant', 'common'))
return self._OAUTH_ACCESS_TOKEN_URL_.format(tenant=tenant)

@property
def _OAUTH_AUTHORIZE_URL(self):
return self._OAUTH_AUTHORIZE_URL_.format(**config.oauth_extra_params)
tenant = os.environ.get('AAD_TENANT_ID', config.oauth_extra_params.get('tenant', 'common'))
return self._OAUTH_AUTHORIZE_URL_.format(tenant=tenant)

@property
def _OAUTH_USER_URL(self):
Expand Down

0 comments on commit 6fc6081

Please sign in to comment.