Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #4106: proprietary oauth #4224

Merged
merged 11 commits into from
Jun 20, 2022
9 changes: 9 additions & 0 deletions etc/run-flash.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash
set -eou pipefail

export SIREPO_FEATURE_CONFIG_PROPRIETARY_OAUTH_SIM_TYPES=flash
if [[ ! ${SIREPO_SIM_OAUTH_FLASH_KEY:-} || ! ${SIREPO_SIM_OAUTH_FLASH_SECRET:-} ]]; then
echo 'You must set $SIREPO_SIM_OAUTH_FLASH_KEY and $SIREPO_SIM_OAUTH_FLASH_SECRET' 1>&2
exit 1
fi
sirepo service http
6 changes: 4 additions & 2 deletions etc/run-jupyterhub.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ elif [[ ! $SIREPO_AUTH_METHODS =~ 'email' ]]; then
export SIREPO_AUTH_METHODS=$SIREPO_AUTH_METHODS:email
fi

if [[ ${SIREPO_AUTH_GITHUB_KEY:-} || ${SIREPO_AUTH_GITHUB_SECRET:-} ]]; then
if [[ ${SIREPO_AUTH_GITHUB_KEY:-} && ${SIREPO_AUTH_GITHUB_SECRET:-} ]]; then
if [[ ! $SIREPO_AUTH_METHODS =~ 'github' ]]; then
export SIREPO_AUTH_METHODS=$SIREPO_AUTH_METHODS:github
fi
export SIREPO_AUTH_GITHUB_METHOD_VISIBLE=0
export SIREPO_AUTH_METHODS="$SIREPO_AUTH_METHODS:github"
export SIREPO_SIM_API_JUPYTERHUBLOGIN_RS_JUPYTER_MIGRATE=1
fi

Expand Down
28 changes: 19 additions & 9 deletions sirepo/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,21 @@ def api_authCompleteRegistration(self):
_parse_display_name(self.parse_json().get('displayName')),
)
return self.reply_ok()


@api_perm.allow_visitor
def api_authState(self):
return self.reply_static_jinja(
'auth-state',
'js',
PKDict(auth_state=_auth_state()),
)


@api_perm.allow_visitor
def api_authLogout(self, simulation_type=None):
"""Set the current user as logged out.

Redirects to root simulation page.
"""
req = None
Expand Down Expand Up @@ -356,17 +356,27 @@ def _moderate(uid, role):
require_email_user()
raise sirepo.util.SRException('moderationRequest', None)

def _oauth_redirect(role):
import sirepo.oauth
raise util.Redirect(
sirepo.oauth.create_authorize_redirect(
sirepo.auth_role.sim_type(role),
)
)

if sim_type not in sirepo.feature_config.auth_controlled_sim_types():
return
u = _assert_login()
if u is None:
return
r = sirepo.auth_role.for_sim_type(sim_type)
if auth_db.UserRole.has_role(u, r):
if auth_db.UserRole.has_role(u, r) and not auth_db.UserRole.is_expired(u, r):
return
if r not in sirepo.auth_role.for_moderated_sim_types():
sirepo.util.raise_forbidden(f'uid={u} does not have access to sim_type={sim_type}')
_moderate(u, r)
elif r in sirepo.auth_role.for_proprietary_oauth_sim_types():
_oauth_redirect(r)
if r in sirepo.auth_role.for_moderated_sim_types():
_moderate(u, r)
sirepo.util.raise_forbidden(f'uid={u} does not have access to sim_type={sim_type}')


def require_email_user():
Expand Down
85 changes: 24 additions & 61 deletions sirepo/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,10 @@
from sirepo import api_perm
from sirepo import auth
from sirepo import auth_db
from sirepo import cookie
from sirepo import feature_config
from sirepo import http_reply
from sirepo import uri_router
from sirepo import util
import authlib.integrations.requests_client
import authlib.oauth2.rfc6749.errors
import flask
import sirepo.api
import sirepo.events
import sirepo.oauth
import sqlalchemy


Expand All @@ -33,34 +27,14 @@

#: Well known alias for auth
UserModel = None

#: module handle
this_module = pkinspect.this_module()

#: cookie keys for github (prefix is "srag")
_COOKIE_NONCE = 'sragn'
_COOKIE_SIM_TYPE = 'srags'


class API(sirepo.api.Base):
@api_perm.allow_cookieless_set_user
def api_authGithubAuthorized(self):
"""Handle a callback from a successful OAUTH request.

Tracks oauth users in a database.
"""
# clear temporary cookie values first
s = cookie.unchecked_remove(_COOKIE_NONCE)
t = cookie.unchecked_remove(_COOKIE_SIM_TYPE)
oc = _client()
try:
oc.fetch_token(
authorization_response=flask.request.url,
state=s,
)
except authlib.oauth2.rfc6749.errors.MismatchingStateException:
auth.login_fail_redirect(t, this_module, 'oauth-state', reload_js=True)
raise AssertionError('auth.login_fail_redirect returned unexpectedly')
oc, t = sirepo.oauth.check_authorized_callback(github_auth=True)
d = oc.get('https://api.github.com/user').json()
sirepo.events.emit('github_authorized', PKDict(user_name=d['login']))
with util.THREAD_LOCK:
Expand All @@ -71,29 +45,26 @@ def api_authGithubAuthorized(self):
else:
u = AuthGithubUser(oauth_id=d['id'], user_name=d['login'])
u.save()
auth.login(this_module, model=u, sim_type=t, sapi=self, want_redirect=True)
auth.login(
pkinspect.this_module(),
model=u,
sim_type=t,
want_redirect=True,
)
raise AssertionError('auth.login returned unexpectedly')


@api_perm.require_cookie_sentinel
def api_authGithubLogin(self, simulation_type):
"""Redirects to Github"""
req = self.parse_params(type=simulation_type)
s = util.random_base62()
cookie.set_value(_COOKIE_NONCE, s)
cookie.set_value(_COOKIE_SIM_TYPE, req.type)
if not cfg.callback_uri:
# must be executed in an app and request context so can't
# initialize earlier.
cfg.callback_uri = uri_router.uri_for_api('authGithubAuthorized')
u, _ = _client().create_authorization_url(
'https://github.com/login/oauth/authorize',
redirect_uri=cfg.callback_uri,
state=s,
)
return self.reply_redirect(u)


raise util.Redirect(sirepo.oauth.create_authorize_redirect(
self.parse_params(
type=simulation_type,
).type,
github_auth=True,
))


@api_perm.allow_cookieless_set_user
def api_oauthAuthorized(self, oauth_type):
"""Deprecated use `api_authGithubAuthorized`"""
Expand All @@ -107,19 +78,6 @@ def avatar_uri(model, size):
)


def _client(token=None):
"""Makes it easier to mock, see github_srunit.py"""
# OAuth2Session doesn't inherit from OAuth2Mixin for some reason.
# So, supplying api_base_url has no effect.
return authlib.integrations.requests_client.OAuth2Session(
cfg.key,
cfg.secret,
scope='user:email',
token=token,
token_endpoint='https://github.com/login/oauth/access_token',
)


def _init():
def _init_model(base):
"""Creates User class bound to dynamic `db` variable"""
Expand All @@ -135,15 +93,20 @@ class AuthGithubUser(base):

global cfg, AUTH_METHOD_VISIBLE
cfg = pkconfig.init(
authorize_url=('https://github.com/login/oauth/authorize', str, 'url to redirect to for authorization'),
callback_uri=(None, str, 'Github callback URI (defaults to api_authGithubAuthorized)'),
key=pkconfig.Required(str, 'Github key'),
method_visible=(
True,
bool,
'github auth method is visible to users when it is an enabled method',
),
scope=('user:email', str, 'scope of data to request about user'),
secret=pkconfig.Required(str, 'Github secret'),
token_endpoint=('https://github.com/login/oauth/access_token', str, 'url for obtaining access token')
)
cfg.callback_api = 'authGithubAuthorized'

AUTH_METHOD_VISIBLE = cfg.method_visible
auth_db.init_model(_init_model)

Expand Down
61 changes: 53 additions & 8 deletions sirepo/auth_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
:copyright: Copyright (c) 2018-2019 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
import sqlite3

from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdexc, pkdlog, pkdp
import contextlib
Expand All @@ -16,6 +14,7 @@
import sirepo.auth_role
import sirepo.srcontext
import sirepo.srdb
import sirepo.srtime
import sirepo.util


Expand Down Expand Up @@ -88,15 +87,15 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
if force or f not in e:
t.join(f).rename(l.join(f))

s = sirepo.feature_config.cfg().proprietary_sim_types
s = sirepo.feature_config.proprietary_sim_types()
if sim_types:
assert sim_types.issubset(s), \
f'sim_types={sim_types} not a subset of proprietary_sim_types={s}'
s = sim_types
for t in s:
c = sirepo.sim_data.get_class(t)
if not c.proprietary_code_tarball():
return
continue
d = sirepo.srdb.proprietary_code_dir(t)
assert d.exists(), \
f'{d} proprietary_code_dir must exist' \
Expand Down Expand Up @@ -159,6 +158,23 @@ def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

@classmethod
def add_column_if_not_exists(cls, table, column, column_type):
column_type = column_type.upper()
t = table.__table__.name
r = cls._execute_raw_sql(f'PRAGMA table_info({t})')
for c in r.all():
if not c[1] == column:
continue
assert c[2] == column_type, \
(
f'unexpected column={c} when adding column={column} of',
f' type={column_type} to table={table}',
)
return
r = cls._execute_raw_sql(f'ALTER TABLE {t} ADD {column} {column_type}')
cls._session().commit()

@classmethod
def all(cls):
with sirepo.util.THREAD_LOCK:
Expand Down Expand Up @@ -191,7 +207,7 @@ def delete_user(cls, uid):

@classmethod
def execute(cls, statement):
cls._session().execute(
return cls._session().execute(
statement.execution_options(synchronize_session='fetch')
)

Expand Down Expand Up @@ -232,6 +248,10 @@ def delete_all_for_column_by_values(cls, column, values):
))
cls._session().commit()

@classmethod
def _execute_raw_sql(cls, text):
return cls.execute(sqlalchemy.text(text + ';'))

@classmethod
def _session(cls):
return sirepo.srcontext.get(_SRCONTEXT_SESSION_KEY)
Expand All @@ -258,6 +278,7 @@ class UserRole(UserDbBase):
__tablename__ = 'user_role_t'
uid = sqlalchemy.Column(UserDbBase.STRING_ID, primary_key=True)
role = sqlalchemy.Column(UserDbBase.STRING_NAME, primary_key=True)
expiration = sqlalchemy.Column(sqlalchemy.DateTime())

@classmethod
def all_roles(cls):
Expand All @@ -267,15 +288,27 @@ def all_roles(cls):
]

@classmethod
def add_roles(cls, uid, roles):
def add_roles(cls, uid, role_or_roles, expiration=None):
if isinstance(role_or_roles, str):
role_or_roles = [role_or_roles]
with sirepo.util.THREAD_LOCK:
for r in roles:
for r in role_or_roles:
try:
UserRole(uid=uid, role=r).save()
UserRole(uid=uid, role=r, expiration=expiration).save()
except sqlalchemy.exc.IntegrityError:
pass
audit_proprietary_lib_files(uid)

@classmethod
def add_role_or_update_expiration(cls, uid, role, expiration):
with sirepo.util.THREAD_LOCK:
if not cls.has_role(uid, role):
cls.add_roles(uid, role, expiration=expiration)
return
r = cls.search_by(uid=uid, role=role)
r.expiration = expiration
r.save()

@classmethod
def delete_roles(cls, uid, roles):
with sirepo.util.THREAD_LOCK:
Expand All @@ -300,6 +333,18 @@ def has_role(cls, uid, role):
with sirepo.util.THREAD_LOCK:
return bool(cls.search_by(uid=uid, role=role))


@classmethod
def is_expired(cls, uid, role):
with sirepo.util.THREAD_LOCK:
assert cls.has_role(uid, role), \
f'No role for uid={uid} and role={role}'
r = cls.search_by(uid=uid, role=role)
if not r.expiration:
# Roles with no expiration can't expire
return False
return r.expiration < sirepo.srtime.utc_now()

@classmethod
def uids_of_paid_users(cls):
return [
Expand Down
6 changes: 5 additions & 1 deletion sirepo/auth_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
:copyright: Copyright (c) 2021 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from __future__ import absolute_import, division, print_function
from pykern import pkconfig
from pykern.pkdebug import pkdp
import sirepo.feature_config

ROLE_ADM = 'adm'
Expand All @@ -25,6 +25,10 @@ def for_new_user(is_guest):
return []


def for_proprietary_oauth_sim_types():
return [for_sim_type(s) for s in sirepo.feature_config.cfg().proprietary_oauth_sim_types]


def for_sim_type(sim_type):
return _SIM_TYPE_ROLE_PREFIX + sim_type

Expand Down
10 changes: 9 additions & 1 deletion sirepo/db_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,15 @@ def _20210301_migrate_role_jupyterhub():
r in sirepo.auth_db.UserRole.all_roles():
return
for u in sirepo.auth_db.all_uids():
sirepo.auth_db.UserRole.add_roles(u, [r])
sirepo.auth_db.UserRole.add_roles(u, r)


def _20220609_add_expiration_column_to_user_role_t():
sirepo.auth_db.UserDbBase.add_column_if_not_exists(
sirepo.auth_db.UserRole,
'expiration',
'datetime',
)


@contextlib.contextmanager
Expand Down
Loading