Skip to content

Commit

Permalink
fix #4106: proprietary oauth
Browse files Browse the repository at this point in the history
  • Loading branch information
e-carlin committed May 30, 2022
1 parent 5d2d18a commit e084efd
Show file tree
Hide file tree
Showing 18 changed files with 274 additions and 92 deletions.
9 changes: 9 additions & 0 deletions etc/run-flash.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash
set -eou pipefail

export SIREPO_FEATURE_CONFIG_PROPRIETARY_OAUTH_SIM_TYPES=flash
if [[ ! ${SIREPO_SIM_OAUTH_FLASH_KEY:-} || ! ${SIREPO_SIM_OAUTH_FLASH_SECRET:-} ]]; then
echo 'You must set $SIREPO_SIM_OAUTH_FLASH_KEY and $SIREPO_SIM_OAUTH_FLASH_SECRET' 1>&2
exit 1
fi
sirepo service http
4 changes: 1 addition & 3 deletions etc/run-jupyterhub.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ elif [[ ! $SIREPO_AUTH_METHODS =~ 'email' ]]; then
export SIREPO_AUTH_METHODS=$SIREPO_AUTH_METHODS:email
fi

if [[ ${SIREPO_AUTH_GITHUB_KEY:-} || ${SIREPO_AUTH_GITHUB_SECRET:-} ]]; then
export SIREPO_AUTH_GITHUB_METHOD_VISIBLE=0
export SIREPO_AUTH_METHODS="$SIREPO_AUTH_METHODS:github"
if [[ ${SIREPO_SIM_OAUTH_JUPYTERHUBLOGIN_KEY:-} || ${SIREPO_SIM_OAUTH_JUPYTERHUBLOGIN_SECRET:-} ]]; then
export SIREPO_SIM_API_JUPYTERHUBLOGIN_RS_JUPYTER_MIGRATE=1
fi

Expand Down
26 changes: 18 additions & 8 deletions sirepo/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,21 @@ def api_authCompleteRegistration(self):
_parse_display_name(self.parse_json().get('displayName')),
)
return http_reply.gen_json_ok()


@api_perm.allow_visitor
def api_authState(self):
return http_reply.render_static_jinja(
'auth-state',
'js',
PKDict(auth_state=_auth_state()),
)


@api_perm.allow_visitor
def api_authLogout(self, simulation_type=None):
"""Set the current user as logged out.
Redirects to root simulation page.
"""
req = None
Expand Down Expand Up @@ -361,6 +361,14 @@ def _moderate(uid, role):
require_email_user()
raise sirepo.util.SRException('moderationRequest', None)

def _oauth_redirect(role):
import sirepo.oauth
raise util.Redirect(
sirepo.oauth.create_authorize_redirect(
sirepo.auth_role.sim_type(role),
)
)

if sim_type not in sirepo.feature_config.auth_controlled_sim_types():
return
u = _assert_login()
Expand All @@ -369,9 +377,11 @@ def _moderate(uid, role):
r = sirepo.auth_role.for_sim_type(sim_type)
if auth_db.UserRole.has_role(u, r):
return
if r not in sirepo.auth_role.for_moderated_sim_types():
sirepo.util.raise_forbidden(f'uid={u} does not have access to sim_type={sim_type}')
_moderate(u, r)
elif r in sirepo.auth_role.for_proprietary_oauth_sim_types():
_oauth_redirect(r)
if r in sirepo.auth_role.for_moderated_sim_types():
_moderate(u, r)
sirepo.util.raise_forbidden(f'uid={u} does not have access to sim_type={sim_type}')


def require_email_user():
Expand Down
82 changes: 25 additions & 57 deletions sirepo/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@
from sirepo import auth
from sirepo import auth_db
from sirepo import cookie
from sirepo import feature_config
from sirepo import http_reply
from sirepo import uri_router
from sirepo import util
import authlib.integrations.requests_client
import authlib.oauth2.rfc6749.errors
import flask
import sirepo.events
import sirepo.oauth
import sirepo.request
import sqlalchemy

Expand All @@ -34,33 +31,14 @@
#: Well known alias for auth
UserModel = None

#: module handle
this_module = pkinspect.this_module()

#: cookie keys for github (prefix is "srag")
_COOKIE_NONCE = 'sragn'
_COOKIE_SIM_TYPE = 'srags'


class Request(sirepo.request.Base):
@api_perm.allow_cookieless_set_user
def api_authGithubAuthorized(self):
"""Handle a callback from a successful OAUTH request.
Tracks oauth users in a database.
"""
# clear temporary cookie values first
s = cookie.unchecked_remove(_COOKIE_NONCE)
t = cookie.unchecked_remove(_COOKIE_SIM_TYPE)
oc = _client()
try:
oc.fetch_token(
authorization_response=flask.request.url,
state=s,
)
except authlib.oauth2.rfc6749.errors.MismatchingStateException:
auth.login_fail_redirect(t, this_module, 'oauth-state', reload_js=True)
raise AssertionError('auth.login_fail_redirect returned unexpectedly')
oc, t = sirepo.oauth.check_authorized_callback(github_auth=True)
d = oc.get('https://api.github.com/user').json()
sirepo.events.emit('github_authorized', PKDict(user_name=d['login']))
with util.THREAD_LOCK:
Expand All @@ -71,29 +49,27 @@ def api_authGithubAuthorized(self):
else:
u = AuthGithubUser(oauth_id=d['id'], user_name=d['login'])
u.save()
auth.login(this_module, model=u, sim_type=t, want_redirect=True)
auth.login(
pkinspect.this_module(),
model=u,
sim_type=t,
want_redirect=True,
)
raise AssertionError('auth.login returned unexpectedly')


@api_perm.require_cookie_sentinel
def api_authGithubLogin(self, simulation_type):
"""Redirects to Github"""
req = self.parse_params(type=simulation_type)
s = util.random_base62()
cookie.set_value(_COOKIE_NONCE, s)
cookie.set_value(_COOKIE_SIM_TYPE, req.type)
if not cfg.callback_uri:
# must be executed in an app and request context so can't
# initialize earlier.
cfg.callback_uri = uri_router.uri_for_api('authGithubAuthorized')
u, _ = _client().create_authorization_url(
'https://github.com/login/oauth/authorize',
redirect_uri=cfg.callback_uri,
state=s,
)
return http_reply.gen_redirect(u)


import sirepo.oauth
raise sirepo.util.Redirect(sirepo.oauth.create_authorize_redirect(
self.parse_params(
type=simulation_type,
).type,
github_auth=True,
))


@api_perm.allow_cookieless_set_user
def api_oauthAuthorized(self, oauth_type):
"""Deprecated use `api_authGithubAuthorized`"""
Expand All @@ -107,19 +83,6 @@ def avatar_uri(model, size):
)


def _client(token=None):
"""Makes it easier to mock, see github_srunit.py"""
# OAuth2Session doesn't inherit from OAuth2Mixin for some reason.
# So, supplying api_base_url has no effect.
return authlib.integrations.requests_client.OAuth2Session(
cfg.key,
cfg.secret,
scope='user:email',
token=token,
token_endpoint='https://github.com/login/oauth/access_token',
)


def _init():
def _init_model(base):
"""Creates User class bound to dynamic `db` variable"""
Expand All @@ -135,15 +98,20 @@ class AuthGithubUser(base):

global cfg, AUTH_METHOD_VISIBLE
cfg = pkconfig.init(
authorize_url=('https://github.com/login/oauth/authorize', str, 'url to redirect to for authorization'),
callback_uri=(None, str, 'Github callback URI (defaults to api_authGithubAuthorized)'),
key=pkconfig.Required(str, 'Github key'),
method_visible=(
True,
bool,
'github auth method is visible to users when it is an enabled method',
),
scope=('user:email', str, 'scope of data to request about user'),
secret=pkconfig.Required(str, 'Github secret'),
token_endpoint=('https://github.com/login/oauth/access_token', str, 'url for obtaining access token')
)
cfg.callback_api = 'authGithubAuthorized'

AUTH_METHOD_VISIBLE = cfg.method_visible
auth_db.init_model(_init_model)

Expand Down
35 changes: 30 additions & 5 deletions sirepo/auth_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
:copyright: Copyright (c) 2018-2019 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
import sqlite3

from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdexc, pkdlog, pkdp
import contextlib
Expand All @@ -16,6 +14,7 @@
import sirepo.auth_role
import sirepo.srcontext
import sirepo.srdb
import sirepo.srtime
import sirepo.util


Expand Down Expand Up @@ -258,6 +257,7 @@ class UserRole(UserDbBase):
__tablename__ = 'user_role_t'
uid = sqlalchemy.Column(UserDbBase.STRING_ID, primary_key=True)
role = sqlalchemy.Column(UserDbBase.STRING_NAME, primary_key=True)
expiration = sqlalchemy.Column(sqlalchemy.DateTime())

@classmethod
def all_roles(cls):
Expand All @@ -267,15 +267,28 @@ def all_roles(cls):
]

@classmethod
def add_roles(cls, uid, roles):
def add_roles(cls, uid, role_or_roles, expiration=None):
if isinstance(role_or_roles, str):
role_or_roles = [role_or_roles]
with sirepo.util.THREAD_LOCK:
for r in roles:
for r in role_or_roles:
try:
UserRole(uid=uid, role=r).save()
UserRole(uid=uid, role=r, expiration=expiration).save()
except sqlalchemy.exc.IntegrityError:
pass
audit_proprietary_lib_files(uid)

@classmethod
def add_role_or_update_expiration(cls, uid, role, expiration):
assert isinstance(role, str)
with sirepo.util.THREAD_LOCK:
if not cls.has_role(uid, role):
cls.add_roles(uid, role, expiration=expiration)
return
r = cls.search_by(uid=uid, role=role)
r.expiration = expiration
r.save()

@classmethod
def delete_roles(cls, uid, roles):
with sirepo.util.THREAD_LOCK:
Expand All @@ -300,6 +313,18 @@ def has_role(cls, uid, role):
with sirepo.util.THREAD_LOCK:
return bool(cls.search_by(uid=uid, role=role))


@classmethod
def is_expired(cls, uid, role):
with sirepo.util.THREAD_LOCK:
assert cls.has_role(uid, role), \
f'No role for uid={uid} and role={role}'
r = cls.search_by(uid=uid, role=role)
if not r.expiration:
# Roles with no expiration can't expire
return False
return r.expiration < sirepo.srtime.utc_now()

@classmethod
def uids_of_paid_users(cls):
return [
Expand Down
6 changes: 5 additions & 1 deletion sirepo/auth_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
:copyright: Copyright (c) 2021 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from __future__ import absolute_import, division, print_function
from pykern import pkconfig
from pykern.pkdebug import pkdp
import sirepo.feature_config

ROLE_ADM = 'adm'
Expand All @@ -25,6 +25,10 @@ def for_new_user(is_guest):
return []


def for_proprietary_oauth_sim_types():
return [for_sim_type(s) for s in sirepo.feature_config.cfg().proprietary_oauth_sim_types]


def for_sim_type(sim_type):
return _SIM_TYPE_ROLE_PREFIX + sim_type

Expand Down
2 changes: 1 addition & 1 deletion sirepo/db_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _20210301_migrate_role_jupyterhub():
r in sirepo.auth_db.UserRole.all_roles():
return
for u in sirepo.auth_db.all_uids():
sirepo.auth_db.UserRole.add_roles(u, [r])
sirepo.auth_db.UserRole.add_roles(u, r)


@contextlib.contextmanager
Expand Down
2 changes: 1 addition & 1 deletion sirepo/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"""
# Limit imports
from pykern.pkcollections import PKDict
import aenum
from pykern.pkdebug import pkdp

#: Map of events to handlers. Note: this is the list of all possible events.
_MAP = PKDict(
Expand Down
9 changes: 7 additions & 2 deletions sirepo/feature_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def auth_controlled_sim_types():
frozenset: enabled sim types that require role
"""
return frozenset(
cfg().proprietary_sim_types.union(cfg().moderated_sim_types),
cfg().proprietary_sim_types.union(cfg().moderated_sim_types, cfg().proprietary_oauth_sim_types),
)


Expand Down Expand Up @@ -111,6 +111,7 @@ def b(msg, dev=False):
hide_guest_warning=b('Hide the guest warning in the UI', dev=True),
),
moderated_sim_types=(set(), set, 'codes where all users must be authorized via moderation'),
proprietary_oauth_sim_types=(set(), set, 'codes authorized through oauth'),
jspec=dict(
derbenevskrinsky_force_formula=b('Include Derbenev-Skrinsky force formula'),
),
Expand Down Expand Up @@ -142,7 +143,11 @@ def b(msg, dev=False):
PROD_FOSS_CODES if pkconfig.channel_in('prod') else FOSS_CODES
)
)
s.update(_cfg.proprietary_sim_types, _cfg.moderated_sim_types)
s.update(
_cfg.proprietary_sim_types,
_cfg.moderated_sim_types,
_cfg.proprietary_oauth_sim_types,
)
for v in _DEPENDENT_CODES:
if v[0] in s:
s.add(v[1])
Expand Down
6 changes: 3 additions & 3 deletions sirepo/github_srunit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MockOAuthClient(object):

def __init__(self, monkeypatch, user_name='joeblow'):
from pykern import pkcollections
from sirepo.auth import github
import sirepo.oauth

self.values = PKDict({
'access_token': 'xyzzy',
Expand All @@ -22,7 +22,7 @@ def __init__(self, monkeypatch, user_name='joeblow'):
login=user_name,
),
})
monkeypatch.setattr(github, '_client', self)
monkeypatch.setattr(sirepo.oauth, '_client', self)

def __call__(self, *args, **kwargs):
return self
Expand All @@ -36,7 +36,7 @@ def create_authorization_url(self, *args, **kwargs):
import sirepo.http_reply

self.values.redirect_uri = kwargs['redirect_uri']
self.values.state = kwargs['state']
self.values.state = 'xxyyzz'
return f'https://github.com/login/oauth/oauthorize?response_type=code&client_id={github.cfg.key}&redirect_uri={github.cfg.callback_uri}&state={self.values.state}', self.values.state

def fetch_token(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit e084efd

Please sign in to comment.